# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List

from fusion_base import Fusion
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionLayerNormalization(Fusion):
    """Fuse a decomposed layer-normalization subgraph rooted at ReduceMean into one LayerNormalization node."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (epsilon)         ^
                                     |                                                |
                                     +------------------------------------------------+

        It also handles cases of duplicated Sub nodes exported from older versions of PyTorch:
              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div  --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+

        Optional Cast nodes are accepted between Sub and Pow, between Sub and Div, and
        between Div and Mul. The epsilon constant of the Add must satisfy 0 < epsilon <= 1e-4.

        NOTE(review): the ReduceMean axes are not validated here; the diagram assumes the
        last axis (axis=2 or -1). Only the 1-D weight/bias checks constrain the pattern.

        Args:
            node: a ReduceMean node; fusion proceeds only if its children are Sub(root, ...).
            input_name_to_nodes: map from tensor name to the nodes consuming it.
            output_name_to_node: map from tensor name to the node producing it.

        On success the matched nodes are queued in self.nodes_to_remove and the new
        LayerNormalization node is queued in self.nodes_to_add; otherwise nothing is queued.
        """
        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        div_node = None
        for child in children:
            # Check if Sub --> Div exists
            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)

            # Check if Sub --> Cast --> Div
            div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])

            if div_node_1 is not None:
                div_node = div_node_1
            elif div_node_2 is not None:
                div_node = div_node_2[-1]
        if div_node is None:
            return

        # Walk back from the Div's denominator: Sqrt <- Add(epsilon) <- ReduceMean <- Pow <- [Cast] <- Sub.
        path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        # The variance branch must start from one of the Sub children of this ReduceMean.
        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        second_add_node = parent_nodes[1]
        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        temp_node = input_name_to_nodes[div_node.output[0]][0]
        if temp_node.op_type == "Cast":
            # Div --> Cast --> Mul
            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
            mul_node = input_name_to_nodes[temp_node.output[0]][0]
        else:
            # Div --> Mul
            mul_node = temp_node
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes.append(node)
        subgraph_nodes.extend(children)
        # The last matched parent is the Sub already included via `children`.
        subgraph_nodes.extend(parent_nodes[:-1])

        subgraph_nodes.extend([last_add_node, mul_node, div_node])
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        # The weight is whichever Mul input is not the normalized value (Div or its Cast).
        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
            return

        # The bias is whichever Add input is not the Mul output.
        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionLayerNormalizationNCHW(Fusion):
    """Fuse channel-first layer normalization (ReduceMean over axis 1) into Transpose -> LayerNormalization -> Transpose."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def get_weight_or_bias(self, output_name, description):
        """Return a constant tensor of shape (C, 1, 1) flattened to shape (C,), or None.

        Args:
            output_name: name of the tensor to look up as a constant.
            description: label used only in debug log messages.

        Returns:
            The constant value reshaped to 1-D, or None when the tensor is not a constant
            or its shape is not Cx1x1.
        """
        value = self.model.get_constant_value(output_name)
        if value is None:
            logger.debug(f"{description} {output_name} is not initializer.")
            return None

        if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
            logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
            return None

        return value.reshape([value.shape[0]])

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Create (but do not insert) a Transpose node that reads `input_name`.

        Args:
            input_name: tensor to transpose.
            perm: value of the Transpose `perm` attribute.
            output_name: output tensor name; derived from the node and input names when None.

        Returns:
            The new Transpose node. The caller is responsible for adding it to the graph.
        """
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              | NxCxHxW              |
              |                      v                                                  (Cx1x1)  (Cx1x1)
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
                     (axes=1)        |      (Y=2)     (axes=1)      (epsilon)         ^
                                     |                                                |
                                     +------------------------------------------------+

        Fused subgraph:
                       (0,2,3,1)                            (0,3,1,2)
          [Root] --> Transpose --> LayerNormalization --> Transpose -->

        The Cx1x1 weight and bias are flattened and stored as new initializers named
        `<name>_NHWC`, so LayerNormalization normalizes over the last (channel) axis.

        Args:
            node: a ReduceMean node; fusion proceeds only if its `axes` attribute is [1].
            input_name_to_nodes: map from tensor name to the nodes consuming it.
            output_name_to_node: map from tensor name to the node producing it.
        """
        axes = OnnxModel.get_node_attribute(node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) != 1:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        sub = children[0]
        div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
        if div_node is None:
            return

        parent_nodes = self.model.match_parent_path(
            div_node,
            ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
            [1, 0, 0, 0, 0],
            output_name_to_node,
        )
        if parent_nodes is None:
            return

        _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
        if sub != sub_node:
            return

        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

        # The variance must be reduced over the channel axis too. Bail out instead of asserting
        # when `axes` is not a list attribute (e.g. ReduceMean supplies axes as an input from
        # opset 18), so one unusual node does not abort the whole optimization.
        axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # Unlike the generic fusion, no Cast is accepted between Div and Mul here.
        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes.append(node)
        subgraph_nodes.extend(parent_nodes)
        subgraph_nodes.extend([last_add_node, mul_node, div_node])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        # The weight is whichever Mul input is not the Div output.
        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
        weight = self.get_weight_or_bias(weight_input, "layernorm weight")
        if weight is None:
            return

        # The bias is whichever Add input is not the Mul output.
        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        bias = self.get_weight_or_bias(bias_input, "layernorm bias")
        if bias is None:
            return

        weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)
        # The bias initializer must carry the bias values (previously it duplicated the weight).
        bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, bias.shape, bias)
        self.model.add_initializer(weight_nhwc, self.this_graph_name)
        self.model.add_initializer(bias_nhwc, self.this_graph_name)

        self.nodes_to_remove.extend(subgraph_nodes)

        # NCHW -> NHWC so that LayerNormalization normalizes over the channel axis (now last).
        transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])

        layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")

        # NHWC -> NCHW, writing to the original output tensor so downstream consumers are unchanged.
        transpose_output = self.create_transpose_node(
            layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
        )

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
            outputs=[layernorm_node_name + "_out_nhwc"],
            name=layernorm_node_name,
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])

        for new_node in (transpose_input, normalize_node, transpose_output):
            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        counter_name = "LayerNormalization(NHWC)"
        self.increase_counter(counter_name)


class FusionLayerNormalizationTF(Fusion):
    """Fuse the layer-normalization subgraph produced by TensorFlow converters, anchored at the final Add."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Norm from a Tensorflow model (converted with keras2onnx or tf2onnx).

        Pattern (names follow the local variables below; Cast_1/Cast_2/Cast_3 appear only
        in the Cast variant):

            x     = Root  or  Cast_1(Root)
            mean  = ReduceMean(x)                                      # reduce_mean_node_1
            d     = Sub(x, mean)                                       # sub_node_1
            var   = ReduceMean(Mul(d, d))  [-> Cast_3]                 # mul_node_2, reduce_mean_node_0
            scale = Mul(Reciprocal(Sqrt(Add(var, epsilon))), weight)   # add_node_0 ... mul_node_1
            out   = Add(Mul(Root, scale),                              # mul_node_3, node
                        Sub(bias, Mul([Cast_2](mean), scale)))         # mul_node_0, sub_node_0

        NOTE(review): not every edge above is verified by the matcher (for example the mean
        input of mul_node_0 is only checked for a Cast in the Cast variant); the pattern
        follows the converter output this fusion was written for.

        Args:
            node: the final Add node of the pattern.
            input_name_to_nodes: map from tensor name to the nodes consuming it (unused here;
                the safety check queries fresh maps from the model).
            output_name_to_node: map from tensor name to the node producing it.
        """
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, None, 0, 0, None],
                ),
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "Cast",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
                ),
            ],
            output_name_to_node,
        )

        if parent_nodes is None:
            return

        # Both paths contain three None (any-input) positions; one matched index is reported per position.
        assert len(return_indice) == 3
        if not all(index in (0, 1) for index in return_indice):
            logger.debug(f"return indice is expected in [0, 1], but got {return_indice}")
            return

        sub_node_0, mul_node_0, mul_node_1, reciprocal_node, sqrt_node, add_node_0 = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        # Only the second (11-node) path has Cast_3 between ReduceMean and Add.
        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == "Cast"

        mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        if cast_node_3 is None:
            root_node = node_before_reduce
        else:
            # In the Cast variant the mean is computed on Cast_1(Root). That node is removed below
            # and its input becomes the LayerNormalization input, so it must really be a Cast.
            if node_before_reduce is None or node_before_reduce.op_type != "Cast":
                logger.debug("cast node before reduce_mean_node_1 not found")
                return
            root_node = self.model.get_parent(node_before_reduce, 0, output_name_to_node)
        if root_node is None:
            logger.debug("root node is none")
            return

        _, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        # Un-normalized input tensor: it must feed both the mean computation and mul_node_3.
        root_input = reduce_mean_node_1.input[0] if cast_node_3 is None else node_before_reduce.input[0]
        if root_input not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input:
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        # Mul(d, d) computes the squared deviation.
        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return

        subgraph_nodes = [
            node,
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
            reduce_mean_node_0,
            mul_node_2,
            sub_node_1,
            reduce_mean_node_1,
            mul_node_3,
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            node.output,
            self.model.input_name_to_nodes(),
            self.model.output_name_to_node(),
        ):
            logger.debug("not safe to fuse layer normalization")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # mul_node_1 = Mul(Reciprocal, weight) and sub_node_0 = Sub(bias, mul_node_0) per the matched indices.
        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node(
            "LayerNormalization",
            # Use the verified root tensor: Mul is commutative, so it may be either input of mul_node_3.
            inputs=[root_input, weight_input, bias_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
Memory