# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict

from fusion_base import Fusion
from numpy import ndarray
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionBiasAdd(Fusion):
    """Graph fusion that folds a bias ``Add`` plus a skip-connection ``Add`` into one ``BiasAdd`` node."""

    def __init__(self, model: OnnxModel):
        # NOTE(review): the two strings are presumably the fused op type and the
        # op type that triggers fuse() -- confirm against Fusion.__init__.
        super().__init__(model, "BiasAdd", "Add")

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Add bias and Add skip connection into BiasAdd

        Args:
            add_node: candidate ``Add`` node; its second input is expected to be
                the skip connection.
            input_name_to_nodes: not used by this fusion.
            output_name_to_node: forwarded to ``match_parent_path`` to look up
                the producers of ``add_node``'s inputs.

        Returns:
            None. When the pattern matches, the two ``Add`` nodes are queued in
            ``self.nodes_to_remove`` and the new ``BiasAdd`` node is queued in
            ``self.nodes_to_add``; otherwise nothing is modified.
        """
        # Parent chain requested from add_node:
        #   Add (bias) <- MatMul <- BiasSplitGelu <- MatMul <- SkipLayerNormalization
        matched_path = self.model.match_parent_path(
            add_node,
            ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
            [0, None, 0, 0, 0],
            output_name_to_node,
        )
        if matched_path is None:
            return

        bias_add, skip_ln = matched_path[0], matched_path[-1]

        # The skip input must be produced by that same SkipLayerNormalization.
        skip_input = add_node.input[1]
        if skip_input not in skip_ln.output:
            return

        # The bias Add needs one constant input that is a 1-D numpy array.
        # isinstance(..., ndarray) also rejects None, and short-circuiting keeps
        # the .ndim access safe.
        bias_index, bias_value = self.model.get_constant_input(bias_add)
        if (
            not isinstance(bias_index, int)
            or not isinstance(bias_value, ndarray)
            or bias_value.ndim != 1
        ):
            return

        self.nodes_to_remove.extend([add_node, bias_add])

        fused_name = self.model.create_node_name("BiasAdd")
        # Inputs in order: the non-constant operand of the bias Add, the
        # constant bias, then the skip connection.
        bias_add_fused = helper.make_node(
            "BiasAdd",
            inputs=[bias_add.input[1 - bias_index], bias_add.input[bias_index], skip_input],
            outputs=[add_node.output[0]],
            name=fused_name,
            domain="com.microsoft",
        )
        self.nodes_to_add.append(bias_add_fused)
        self.node_name_to_graph_name[fused_name] = self.this_graph_name
Memory