# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipLayerNormalization(Fusion):
    """
    Fuse Add + LayerNormalization into one node: SkipLayerNormalization
    Note: This fusion does not check the input shape of Add and LayerNormalization.
    """

    def __init__(
        self,
        model: OnnxModel,
        fused_op_type: str = "SkipLayerNormalization",
        search_op_types: str = "LayerNormalization",
        shape_infer: bool = True,
    ):
        super().__init__(model, fused_op_type, search_op_types)
        if shape_infer:
            # Re-running shape inference is needed since other fusions might add new edges that do not have shape info yet.
            self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)

            if self.shape_infer_helper is None:
                # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
                logger.warning("symbolic shape inference disabled or failed.")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        add = self.model.get_parent(node, 0, output_name_to_node)

        # In some models there is input_ids->Gather->Add->LayerNorm, and one of the inputs of the
        # Add node is an initializer with a fixed shape, which should not be fused into SkipLayerNorm.
        if add is None or add.op_type != "Add":
            return

        # The Add node should have exactly 2 inputs.
        if len(add.input) != 2:
            return

        for add_input in add.input:
            if self.model.get_initializer(add_input) is not None:
                return

        # To avoid an Add node having two LayerNormalization children, we only fuse one SkipLayerNormalization.
        if add in self.nodes_to_remove:
            return

        # Root Mean Square Layer Normalization
        simplified = node.op_type == "SimplifiedLayerNormalization"

        if hasattr(self, "shape_infer_helper"):
            if self.shape_infer_helper is not None:
                if (
                    self.shape_infer_helper.get_edge_shape(add.input[0])
                    and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
                ):
                    logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
                    return

                # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
                if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
                    logger.debug(
                        "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same",
                        add.input[0],
                        add.input[1],
                    )
                    return
            else:
                logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
                return

        gather_path = self.model.match_parent_path(add, ["Gather"], [None])
        if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
            if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
                return

        # The residual Add before the LayerNormalization may produce an output that is consumed by
        # other nodes or is a graph output (not only by the LayerNormalization itself).
        # We can still go ahead with the SkipLayerNormalization fusion, but we need to preserve the
        # output of Add, which then needs to be produced by SkipLayerNormalization.
        add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
        residual_add_has_multiple_consumers = (
            add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
        )

        outputs_to_keep = node.output
        if residual_add_has_multiple_consumers:
            outputs_to_keep.extend([add.output[0]])

        outputs = [node.output[0]]

        # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output
        if residual_add_has_multiple_consumers:
            outputs.extend(["", "", add.output[0]])

        if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
            self.nodes_to_remove.extend([add, node])

            inputs = (
                [add.input[0], add.input[1], node.input[1], node.input[2]]
                if not simplified
                else [add.input[0], add.input[1], node.input[1]]
            )
            normalize_node = helper.make_node(
                self.fused_op_type,
                inputs=inputs,
                outputs=outputs,
                name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
            )
            normalize_node.domain = "com.microsoft"

            # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization
            for att in node.attribute:
                if att.name == "epsilon":
                    normalize_node.attribute.extend([att])

            # Set default epsilon if no epsilon exists from layernorm
            if len(normalize_node.attribute) == 0:
                normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

            self.nodes_to_add.append(normalize_node)
            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionBiasSkipLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        if len(node.input) != 4:
            return

        return_indice = []
        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
        if nodes is not None:
            (add, _matmul) = nodes
        else:
            # In case of fp16, we could have a Cast between the MatMul and the bias Add
            return_indice = []
            nodes = self.model.match_parent_path(
                node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
            )
            if nodes is not None:
                (add, _cast, _matmul) = nodes
            else:
                return

        assert len(return_indice) == 2 or len(return_indice) == 3
        add_input_index = return_indice[0]
        if add_input_index >= 2:
            return

        sln_input = add.input[return_indice[1]]
        bias_input = add.input[1 - return_indice[1]]
        skip_input = node.input[1 - add_input_index]

        # Bias should be a one-dimensional initializer.
        initializer = self.model.get_initializer(bias_input)
        if initializer is None:
            return

        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return
        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
            logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        inputs = [
            sln_input,
            skip_input,
            node.input[2],
            node.input[3],
            bias_input,
        ]
        new_node = helper.make_node(
            "SkipLayerNormalization",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
        )
        new_node.domain = "com.microsoft"

        # Pass attribute "epsilon" from skiplayernorm node to skiplayernorm(add bias)
        for att in node.attribute:
            if att.name == "epsilon":
                new_node.attribute.extend([att])

        # Set default epsilon if no epsilon exists from skiplayernorm
        if len(new_node.attribute) == 0:
            new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
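

# The sketch below is a minimal, hypothetical usage example of how these fusion passes can be
# driven on a standalone model. The model paths are assumptions, and it assumes Fusion.apply()
# (fusion_base.py) and OnnxModel.save_model_to_file() (onnx_model.py) from the same tooling;
# in practice the transformer optimizer normally invokes these fusions as part of its pipeline.
if __name__ == "__main__":
    import onnx

    # Load an ONNX model and wrap it so the fusion helpers (get_parent, match_parent_path, ...)
    # can traverse and edit the graph in place.
    model_proto = onnx.load("model.onnx")  # hypothetical input path
    onnx_model = OnnxModel(model_proto)

    # First fuse Add + LayerNormalization into SkipLayerNormalization, then fold a 1-D bias Add
    # (following a MatMul) into the optional bias input of the fused node.
    FusionSkipLayerNormalization(onnx_model).apply()
    FusionBiasSkipLayerNormalization(onnx_model).apply()

    onnx_model.save_model_to_file("model_fused.onnx")  # hypothetical output path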