# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict
from fusion_base import Fusion
from numpy import ndarray
from onnx import helper
from onnx_model import OnnxModel
# Module-level logger; FusionBiasAdd below does not currently emit any log records.
logger = getLogger(__name__)
class FusionBiasAdd(Fusion):
    """Fuse ``(MatMul + bias) + skip`` into a single ``com.microsoft`` BiasAdd node.

    Walking up from the outer ``Add`` (the node ``fuse`` is called with), the
    matched dataflow is::

        SkipLayerNormalization -> MatMul -> BiasSplitGelu -> MatMul -> Add(bias) -> Add(skip)

    The other input of the outer ``Add`` (the skip connection) must be one of
    the outputs of that same SkipLayerNormalization node.
    """

    def __init__(self, model: OnnxModel):
        # Fused op type is "BiasAdd"; "Add" is presumably the op type whose nodes
        # the Fusion base class hands to fuse() as add_node.
        super().__init__(model, "BiasAdd", "Add")

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Add bias and Add skip connection into BiasAdd.

        Args:
            add_node: the outer Add node; one input is the biased MatMul branch
                and the other is the skip connection (either order is accepted).
            input_name_to_nodes: maps a tensor name to the list of nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.

        Returns early without modifying anything when the pattern does not
        match, when the bias is not a 1-D constant, or when the bias Add's
        output is consumed by any node other than ``add_node``. On success,
        ``add_node`` and the bias Add are queued for removal and the fused
        BiasAdd node (producing ``add_node``'s output) is queued for addition.
        """
        parent_types = ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"]

        # Add is commutative: the biased MatMul branch may sit on either input
        # of add_node, with the skip connection on the other one.
        for bias_branch_index in (0, 1):
            nodes = self.model.match_parent_path(
                add_node,
                parent_types,
                [bias_branch_index, None, 0, 0, 0],
                output_name_to_node,
            )
            if nodes is None:
                continue
            skip_input = add_node.input[1 - bias_branch_index]
            # Check skip connection is from SkipLayerNormalization output
            if skip_input in nodes[-1].output:
                break
        else:
            return

        bias_node = nodes[0]

        # bias_node is removed below, so add_node must be the only consumer of
        # its output; otherwise other consumers would be left with a dangling input.
        if len(input_name_to_nodes.get(bias_node.output[0], [])) != 1:
            return

        bias_index, bias_value = self.model.get_constant_input(bias_node)
        # Only fuse when the bias Add has a constant, 1-D bias input.
        if not isinstance(bias_index, int) or not isinstance(bias_value, ndarray) or bias_value.ndim != 1:
            return

        self.nodes_to_remove.extend([add_node, bias_node])
        node_name = self.model.create_node_name("BiasAdd")
        fused_node = helper.make_node(
            "BiasAdd",
            # Inputs: non-constant input of bias_node (the MatMul branch), bias, skip.
            inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], skip_input],
            outputs=[add_node.output[0]],
            name=node_name,
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[node_name] = self.this_graph_name