# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- from __future__ import annotations import onnx from ..onnx_model import ONNXModel from .fusion import Fusion class FusionLayerNormalization(Fusion): def __init__(self, model: ONNXModel): super().__init__(model, "LayerNormalization", "ReduceMean") def fuse( self, reduce_mean_node: onnx.NodeProto, input_name_to_nodes: dict[str, list[onnx.NodeProto]], output_name_to_node: dict[str, onnx.NodeProto], ): """ Interface function that tries to fuse a node sequence containing a ReduceMean node into a single LayerNormalization node. +----------------------+ | | | v [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ | | +-------------------------------------------------+ It also handles cases of duplicated sub nodes exported from older version of PyTorch: +----------------------+ | v | +-------> Sub-----------------------------------------------+ | | | | | v [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add | ^ | | +----------------------+ """ children = self.model.get_children(reduce_mean_node, input_name_to_nodes) if len(children) == 0 or len(children) > 2: return root_input = reduce_mean_node.input[0] if children[0].op_type != "Sub" or children[0].input[0] != root_input: return if len(children) == 2: if children[1].op_type != "Sub" or children[1].input[0] != root_input: return div_node = None for child in children: div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) if div_node is not None: break if div_node is None: return path_id, parent_nodes, _ = self.match_parent_paths( div_node, [ (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), ( ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0], ), ], output_name_to_node, ) if path_id < 0: return sub_node = parent_nodes[-1] if sub_node not in children: return second_add_node = parent_nodes[1] i, add_weight = self.get_constant_input(second_add_node) if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: # Skip fusion since epsilon value is not expected. return pow_node = parent_nodes[3] if self.find_constant_input(pow_node, 2.0) != 1: return mul_node = input_name_to_nodes[div_node.output[0]][0] if mul_node.op_type != "Mul": return last_add_node = input_name_to_nodes[mul_node.output[0]][0] if last_add_node.op_type != "Add": return subgraph_nodes = [reduce_mean_node] subgraph_nodes.extend(children) subgraph_nodes.extend(parent_nodes[:-1]) subgraph_nodes.extend([last_add_node, mul_node, div_node]) if not self.is_safe_to_fuse_nodes( subgraph_nodes, last_add_node.output, input_name_to_nodes, output_name_to_node, ): return weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)] if not self.is_constant_with_specified_rank(weight_input, 1): return bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)] if not self.is_constant_with_specified_rank(bias_input, 1): return self.nodes_to_remove.extend(subgraph_nodes) normalize_node = onnx.helper.make_node( "LayerNormalization", name=self.create_unique_node_name(), inputs=[reduce_mean_node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]], ) normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))]) self.nodes_to_add.append(normalize_node)
Memory