# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging from typing import Optional, Union from fusion_attention import FusionAttention from fusion_base import Fusion from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper from onnx_model import OnnxModel logger = logging.getLogger(__name__) class FusionRotaryAttention(FusionAttention): """ Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node. """ def __init__( self, model: OnnxModel, hidden_size: int, num_heads: int, ): super().__init__( model, hidden_size, num_heads, use_multi_head_attention=True, search_op_types=[ "SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "LayerNormalization", "SkipLayerNormalization", "Add", ], ) def create_mha_node( self, input: str, output: str, q_rotary: NodeProto, k_rotary: NodeProto, v_matmul: NodeProto, attn_mask: str = "", add_qk: str = "", past_k: str = "", past_v: str = "", present_k: str = "", present_v: str = "", scale: Optional[float] = None, ) -> Union[NodeProto, None]: assert self.num_heads > 0 if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0: logger.debug( f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}" ) return None mha_node_name = self.model.create_node_name("MultiHeadAttention") mha_inputs = [ q_rotary.output[0], k_rotary.output[0], v_matmul.output[0], "", # bias attn_mask, # key_padding_mask add_qk, # attention_bias past_k, past_v, ] mha_outputs = [output] if present_k and present_v: mha_outputs.extend([present_k, present_v]) mha_node = helper.make_node( "MultiHeadAttention", inputs=mha_inputs, outputs=mha_outputs, name=mha_node_name, ) mha_node.domain = "com.microsoft" mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) if scale is not None: mha_node.attribute.extend([helper.make_attribute("scale", scale)]) if self.mask_filter_value is not None: mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) self.increase_counter("MultiHeadAttention") return mha_node def check_runtime_shape_paths_for_function( self, reshape_qkv_2, # Reshape after Transpose reshape_qkv_1, # Reshape before Transpose reshape_q_2, # Reshape after RotaryEmbedding reshape_k_2, # Reshape after RotaryEmbedding reshape_v_2, # Reshape after Transpose reshape_v_1, # Reshape before Transpose add_qk, # Add before Softmax root_input, # Root input to attention subgraph ): # Check #1: check paths for qkv nodes concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1]) if concat_qkv_2_path is None or concat_qkv_1_path is None: return False concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0] reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) if ( reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None or reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None ): return False _, gather_1, shape_1 = reshape_qkv_2_path_1 _, gather_2, shape_2 = reshape_qkv_2_path_2 # Check root_input --> Shape --> Gather connection if shape_1.input[0] != root_input or shape_2.input[0] != root_input: return False # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2 if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name: return False # Check #2: check paths for v nodes concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1]) concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1]) if concat_v_2_path is None or concat_v_1_path is None: return False concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0] reshape_v_2_path_1 = self.model.match_parent_path( concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] ) reshape_v_2_path_2 = self.model.match_parent_path( concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0] ) reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if ( reshape_v_2_path_1 is None or reshape_v_2_path_2 is None or reshape_v_1_path_1 is None or reshape_v_1_path_2 is None ): return False # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1 # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2 # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2 if ( reshape_v_2_path_1[2].name != gather_1.name or reshape_v_2_path_2[2].name != gather_2.name or reshape_v_1_path_1[1].name != gather_1.name or reshape_v_1_path_2[1].name != gather_2.name ): return False # Check #3: check paths for k nodes concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1]) if concat_k_2_path is None: return False concat_k_2 = concat_k_2_path[0] reshape_k_2_path_1 = self.model.match_parent_path( concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] ) reshape_k_2_path_2 = self.model.match_parent_path( concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0] ) if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None: return False # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1 # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2 if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name: return False # Check #4: check paths for q nodes concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1]) if concat_q_2_path is None: return False concat_q_2 = concat_q_2_path[0] reshape_q_2_path_1 = self.model.match_parent_path( concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] ) reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None: return False # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1 # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2 if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name: return False # Check #5: check Mul nodes are the same for q, k, v mul_q = reshape_q_2_path_1[1] mul_k = reshape_k_2_path_1[1] mul_v = reshape_v_2_path_1[1] gather_1_out = gather_1.output[0] if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: return False # Check #6: check paths for attention mask nodes attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0]) attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0]) if attn_mask_path_1 is not None: _, slice_qk_2, slice_qk_1 = attn_mask_path_1 elif attn_mask_path_2 is not None: _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2 else: return False # Check first input to Slice #1 is 3D attention mask of shape (B,S,T) if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}: return False slice_qk_2_path = self.model.match_parent_path( slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] ) slice_qk_1_path_1 = self.model.match_parent_path( slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] ) slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1]) if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None: return False # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1 if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name: return False # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2 # Check if first input to Add and Unsqueeze #1 is position ids if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]: return False return True def check_runtime_shape_paths_for_nodes( self, reshape_qkv, # Final reshape before o_proj MatMul reshape_q, # Reshape before q_proj MatMul reshape_k, # Reshape before k_proj MatMul reshape_v, # Reshape before v_proj MatMul root_input, # Root input to attention subgraph ): # Check #1: check paths for qkv nodes concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1]) if concat_qkv_path is None: return False concat_qkv = concat_qkv_path[0] reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None: return False _, gather_1, shape_1 = reshape_qkv_path_1 _, gather_2, shape_2 = reshape_qkv_path_2 # Check root_input --> Shape --> Gather connection if shape_1.input[0] != root_input or shape_2.input[0] != root_input: return False # Check #2: check paths for v nodes concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1]) if concat_v_path is None: return False concat_v = concat_v_path[0] reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if reshape_v_path_1 is None or reshape_v_path_2 is None: return False # Check Gather --> Unsqueeze --> Concat --> Reshape connection if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name: return False # Check #3: check paths for k nodes concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1]) if concat_k_path is None: return False concat_k = concat_k_path[0] reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if reshape_k_path_1 is None or reshape_k_path_2 is None: return False # Check Gather --> Unsqueeze --> Concat --> Reshape connection if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name: return False # Check #4: check paths for q nodes concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1]) if concat_q_path is None: return False concat_q = concat_q_path[0] reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) if reshape_q_path_1 is None or reshape_q_path_2 is None: return False # Check Gather --> Unsqueeze --> Concat --> Reshape connection if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name: return False return True def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}: return # qkv_nodes_1 is for LLaMA-2 Microsoft # qkv_nodes_2 is for LLaMA-2 Hugging Face # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model qkv_nodes = None qkv_nodes_1 = self.model.match_parent_path( normalize_node, ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0], ) qkv_nodes_2 = self.model.match_parent_path( normalize_node, ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], ) qkv_nodes_3 = self.model.match_parent_path( normalize_node, ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0, 0], ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 qkv_nodes = qkv_nodes_1 elif qkv_nodes_2 is not None: _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 qkv_nodes = qkv_nodes_2 elif qkv_nodes_3 is not None: _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 qkv_nodes = qkv_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match qkv nodes") return # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face # v_nodes_4 is for LLaMA-2 70B model # v_nodes_5 is for Phi-2 DirectML past_v, present_v, past_seq_len = "", "", "" v_nodes = None add_v = None v_nodes_1 = self.model.match_parent_path( matmul_qkv, ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 1, 0, 0], ) v_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Concat", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0], ) v_nodes_3 = self.model.match_parent_path( matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) _, v_nodes_4, _ = self.model.match_parent_paths_all( matmul_qkv, [ ( ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 1, 0, 0], ), ( [ "Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], ), ( [ "Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], ), ( [ "Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], ), ( [ "Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0, 0, 0, 1, 0, 0], ), ( [ "Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul", ], [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], [1, 1, 2, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], [1, 1, 3, 0, 0, 0, 1, 0, 0], ), ], output_name_to_node=None, ) v_nodes_5 = self.model.match_parent_path( matmul_qkv, ["Concat", "Transpose", "Reshape", "Add", "MatMul"], [1, 1, 0, 0, 1], ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 concat_v_path = self.model.match_parent_path( concat_v, ["Slice", "Unsqueeze"], [0, 2], ) if concat_v_path is None: logger.debug("fuse_rotary_attention: failed to match past/present concat in v path") return past_v = concat_v_path[0].input[0] past_seq_len = concat_v_path[-1].input[0] present_v = concat_v.output[0] elif v_nodes_2 is not None: concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2 v_nodes = v_nodes_2 past_v = concat_v.input[0] present_v = concat_v.output[0] elif v_nodes_3 is not None: transpose_v, reshape_v, matmul_v = v_nodes_3 v_nodes = v_nodes_3 present_v = transpose_v.output[0] elif v_nodes_4 is not None and len(v_nodes_4) == 9: concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] v_nodes = v_nodes_4 past_v = concat_v.input[0] present_v = concat_v.output[0] elif v_nodes_5 is not None: concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5 matmul_v = add_v v_nodes = v_nodes_5 past_v = concat_v.input[0] present_v = concat_v.output[0] else: logger.debug("fuse_rotary_attention: failed to match v path") return qk_nodes = self.model.match_parent_path( matmul_qkv, ["Softmax", "Add", "Div", "MatMul"], [0, 0, 0, 0], ) add_qk, matmul_qk = None, None if qk_nodes is not None: _, add_qk, _, matmul_qk = qk_nodes else: logger.debug("fuse_rotary_attention: failed to match qk nodes") return # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask attn_mask, add_qk_str = "", "" attn_mask_nodes_1 = self.model.match_parent_path( add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0], ) attn_mask_nodes_2 = self.model.match_parent_path( add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0], ) attn_mask_nodes_3 = self.model.match_parent_path( add_qk, ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 0, 2, 1, 0, 0, 0], ) attn_mask_nodes_4 = self.model.match_parent_path( add_qk, ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 2, 1, 0, 0, 0], ) attn_mask_nodes_5 = self.model.match_parent_path( add_qk, ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 0, 0, 2, 1, 0, 0, 0], ) attn_mask_nodes_6 = self.model.match_parent_path( add_qk, ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 0, 2, 1, 0, 0, 0], ) attn_mask_nodes_7 = self.model.match_parent_path( add_qk, ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], [1, 0, 0, 0, 0, 1, 0, 0, 0], ) if attn_mask_nodes_1 is not None: _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 attn_mask = slice_mask_1.output[0] elif attn_mask_nodes_2 is not None: _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2 attn_mask = slice_mask_1.output[0] elif attn_mask_nodes_3 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0]) elif attn_mask_nodes_4 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) elif attn_mask_nodes_5 is not None: # The mask has already been reshaped to (B,N,S,T) add_qk_str = attn_mask_nodes_5[0].output[0] elif attn_mask_nodes_6 is not None: # The mask has already been reshaped to (B,N,S,T) add_qk_str = attn_mask_nodes_6[0].output[0] elif attn_mask_nodes_7 is not None: # Reshape from (B,1,S,T) to (B,N,S,T) add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0]) else: logger.debug("fuse_rotary_attention: failed to match attention mask nodes") return # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None slice_k = None concat_k_half = None k_nodes_1 = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"], [1, 0, 0, 1, 0, 0], ) k_nodes_2 = self.model.match_parent_path( matmul_qk, ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0], ) k_nodes_3 = self.model.match_parent_path( matmul_qk, ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0, 0], ) _, k_nodes_4, _ = self.model.match_parent_paths_all( matmul_qk, [ ( [ "Transpose", "Reshape", "Expand", "Unsqueeze", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], ), ( [ "Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul", ], [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], ), ], output_name_to_node=None, ) k_nodes_5 = self.model.match_parent_path( matmul_qk, ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 1, 0, 0, 0, 0, 0, 1], ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 concat_k_path = self.model.match_parent_path( concat_k, ["Slice", "Unsqueeze"], [0, 2], ) if concat_k_path is None: logger.debug("fuse_rotary_attention: failed to match past/present concat in k path") return past_k = concat_k_path[0].input[0] shared_past_seq_len = concat_k_path[-1].input[0] present_k = concat_k.output[0] assert past_seq_len == shared_past_seq_len elif k_nodes_2 is not None: _, rotary_k, _, reshape_k, matmul_k = k_nodes_2 k_nodes = k_nodes_2 present_k = rotary_k.output[0] elif k_nodes_3 is not None: _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3 k_nodes = k_nodes_3 past_k = concat_k.input[0] present_k = concat_k.output[0] elif k_nodes_4 is not None and len(k_nodes_4) == 9: reshape_k, matmul_k = k_nodes_4[0][-2:] concat_k, rotary_k = k_nodes_4[0][-5:-3] k_nodes = k_nodes_4 past_k = concat_k.input[0] present_k = concat_k.output[0] elif k_nodes_5 is not None: _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5 k_nodes = k_nodes_5 past_k = concat_k.input[0] present_k = concat_k.output[0] else: logger.debug("fuse_rotary_attention: failed to match k nodes") return # q_nodes_1 is for LLaMA-2 Microsoft # q_nodes_2 is for LLaMA-2 Hugging Face # q_nodes_3 is for Phi-2 DirectML q_nodes = None slice_q = None concat_q_half = None q_nodes_1 = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"], [0, 0, 0, 0], ) q_nodes_2 = self.model.match_parent_path( matmul_qk, ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0], ) q_nodes_3 = self.model.match_parent_path( matmul_qk, ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, 0, 0, 1], ) if q_nodes_1 is not None: reshape_q_2, _, rotary_q, matmul_q = q_nodes_1 q_nodes = q_nodes_1 elif q_nodes_2 is not None: rotary_q, _, reshape_q, matmul_q = q_nodes_2 q_nodes = q_nodes_2 elif q_nodes_3 is not None: concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3 q_nodes = q_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match q nodes") return if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]: logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths") return root_output = "" if qkv_nodes == qkv_nodes_1: if not self.check_runtime_shape_paths_for_function( reshape_qkv_2, reshape_qkv_1, reshape_q_2, reshape_k_2, reshape_v_2, reshape_v_1, add_qk, matmul_q.input[0], ): logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") return root_output = reshape_qkv_2.output[0] elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, reshape_q, reshape_k, reshape_v, matmul_q.input[0], ): logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") return root_output = reshape_qkv.output[0] # Rename inputs of rotary_q/k so it connects with output of matmul_q/k # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding # After: MatMul --> RotaryEmbedding rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0] rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0] # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) if concat_q_half is None: rotary_k.output[0] = rotary_k.name + "_output_0" if qkv_nodes == qkv_nodes_3: qkv_nodes = qkv_nodes[1:] def create_hidden_size_concat_node(reshape_q): """Detect num_heads and hidden_size for ONNX model from phi-2 Args: reshape_q (NodeProto): reshape node for q Returns: hidden_size_concat_node(NodeProto): Concat node to be used by reshape """ concat = self.model.match_parent(reshape_q, "Concat", 1) if concat is None: logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q") return None # The shape is a tensor like [?, ?, num_heads, head_size] num_head_constant_node = self.model.get_constant_value(concat.input[2]) head_size_constant_node = self.model.get_constant_value(concat.input[3]) if num_head_constant_node is None or head_size_constant_node is None: logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size") return None num_head_value = num_head_constant_node[0] head_size_value = head_size_constant_node[0] hidden_size = num_head_value * head_size_value hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size") if self.model.get_initializer(hidden_size_initilizer) is None: self.add_initializer( name=hidden_size_initilizer, data_type=TensorProto.INT64, dims=[1], vals=[hidden_size], raw=False, ) hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat") hidden_size_concat_node = helper.make_node( "Concat", inputs=[ concat.input[0], concat.input[1], hidden_size_initilizer, ], outputs=[hidden_size_reshape_node_name + "output_0"], name=hidden_size_reshape_node_name, ) hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)]) return hidden_size_concat_node # Add Tranpose and Reshape nodes for patial rotary embedding applied in phi-2 before passing into MHA if concat_q_half and concat_k_half: # Transpose the key output of rotary Embedding k_transpose_node_name = self.model.create_node_name("Transpose") k_tranpose_output_name = k_transpose_node_name + "_output_0" k_transpose_node = helper.make_node( "Transpose", inputs=[concat_k_half.output[0]], outputs=[k_tranpose_output_name], name=k_transpose_node_name, ) k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])]) # Transpose the query output of rotary Embedding q_transpose_node_name = self.model.create_node_name("Transpose") q_tranpose_output_name = q_transpose_node_name + "_output_0" q_transpose_node = helper.make_node( "Transpose", inputs=[concat_q_half.output[0]], outputs=[q_tranpose_output_name], name=q_transpose_node_name, ) q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])]) hidden_size_concat_node = create_hidden_size_concat_node(reshape_k) if hidden_size_concat_node is None: logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node") return # Reshape the Rotary Embedding output for key for 4D to 3D concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half") concat_k_reshape_node = helper.make_node( "Reshape", inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]], outputs=[concat_k_reshape_node_name + "_output_0"], name=concat_k_reshape_node_name, ) # Reshape the Rotary Embedding output for query from 4D to 3D concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half") concat_q_reshape_node = helper.make_node( "Reshape", inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]], outputs=[concat_q_reshape_node_name + "_output_0"], name=concat_q_reshape_node_name, ) rotary_k = concat_k_reshape_node rotary_q = concat_q_reshape_node self.nodes_to_add.append(hidden_size_concat_node) self.nodes_to_add.append(k_transpose_node) self.nodes_to_add.append(q_transpose_node) self.nodes_to_add.append(concat_k_reshape_node) self.nodes_to_add.append(concat_q_reshape_node) self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name new_node = self.create_mha_node( matmul_q.input[0], root_output, rotary_q, rotary_k, matmul_v, attn_mask, add_qk_str, past_k, past_v, present_k, present_v, ) if new_node is None: logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings") return self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend(qkv_nodes[1:]) if v_nodes != v_nodes_4: self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2]) else: nodes_to_keep = [v_nodes[0][-1]] for temp_path in v_nodes: self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) self.nodes_to_remove.extend(qk_nodes) if k_nodes == k_nodes_1: self.nodes_to_remove.extend(k_nodes[:-2]) elif k_nodes == k_nodes_2: self.nodes_to_remove.append(k_nodes[0]) self.nodes_to_remove.append(k_nodes[2]) self.nodes_to_remove.append(k_nodes[3]) elif k_nodes == k_nodes_3: self.nodes_to_remove.append(k_nodes[0]) self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) elif k_nodes == k_nodes_5: self.nodes_to_remove.append(k_nodes[0]) self.nodes_to_remove.append(k_nodes[1]) elif k_nodes == k_nodes_4: nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] for temp_path in k_nodes: self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) elif q_nodes == q_nodes_2: self.nodes_to_remove.append(q_nodes[1]) self.nodes_to_remove.append(q_nodes[2]) self.prune_graph = True class FusionRotaryEmbeddings(Fusion): def __init__(self, model: OnnxModel): self.base_name = "RotaryEmbedding" super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"]) # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output. # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter. # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used. def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto): # Find extra outputs and Constant nodes attached to those outputs extra_constants, extra_outputs = [], [] for fn_node in function.node: if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output: extra_constants.append(fn_node) output_index = list(function.output).index(fn_node.output[0]) extra_outputs.append(rot_emb_node.output[output_index]) # Set extra Constant node outputs as initializers extra_initializers = [] for extra_constant in extra_constants: constant_tensorproto = extra_constant.attribute[0].t constant_tensorproto.name = self.model.create_node_name("Constant") self.model.add_initializer(constant_tensorproto) extra_initializers.append(constant_tensorproto.name) # Update references of Constant node outputs to initializer references for extra_output, extra_initializer in zip(extra_outputs, extra_initializers): nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node)) for node_to_update in nodes_to_update: OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer) return extra_outputs def create_rotary_embeddings_from_function(self, node: NodeProto): rotary_emb_node_name = self.model.create_node_name(self.base_name) matmul_path = self.model.match_parent_path( node, ["Reshape", "MatMul"], [0, 0], ) if matmul_path is not None: reshape_node, matmul_node = matmul_path else: logger.debug("fuse_rotary_embeddings: failed to match MatMul") return rotary_emb_inputs = [ matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H) node.input[1], # position_ids ] # Convert cos_cache and sin_cache from node attributes to model initializers cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node)) sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node)) cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" if ( len(cos_cache_node) == 1 and len(sin_cache_node) == 1 and self.model.get_initializer(cos_cache_name) is None and self.model.get_initializer(sin_cache_name) is None ): cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() cos_cache_tensor = helper.make_tensor( name=cos_cache_name, data_type=TensorProto.FLOAT, dims=list(cos_cache.shape), vals=cos_cache.flatten().tolist(), ) self.model.add_initializer(cos_cache_tensor, self.this_graph_name) sin_cache_tensor = helper.make_tensor( name=sin_cache_name, data_type=TensorProto.FLOAT, dims=list(sin_cache.shape), vals=sin_cache.flatten().tolist(), ) self.model.add_initializer(sin_cache_tensor, self.this_graph_name) self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) rotary_emb_inputs.extend([cos_cache_name, sin_cache_name]) rotary_emb_outputs = node.output if len(rotary_emb_outputs) > 1: # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions)) assert len(func) == 1 extra_outputs = self.reassign_extra_outputs(node, func[0]) rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs)) assert len(rotary_emb_outputs) == 1 rotary_emb_node = helper.make_node( self.base_name, inputs=rotary_emb_inputs, outputs=rotary_emb_outputs, name=rotary_emb_node_name, interleaved=1, ) rotary_emb_node.domain = "com.microsoft" self.nodes_to_remove.append(reshape_node) return rotary_emb_node def create_rotary_embeddings_from_nodes( self, root_input: str, position_ids: str, cos_slice: str, sin_slice: str, output: str, ): rotary_emb_node_name = self.model.create_node_name(self.base_name) # Convert cos_cache and sin_cache from node attributes to model initializers cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node)) sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node)) cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" if ( len(cos_cache_node) == 1 and len(sin_cache_node) == 1 and self.model.get_initializer(cos_cache_name) is None and self.model.get_initializer(sin_cache_name) is None ): cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() # Reshape cos/sin cache from (M, H) to (M, H/2) head_size = cos_cache.shape[1] cos_cache = cos_cache[:, : (head_size // 2)] sin_cache = sin_cache[:, : (head_size // 2)] cos_cache_tensor = helper.make_tensor( name=cos_cache_name, data_type=TensorProto.FLOAT, dims=list(cos_cache.shape), vals=cos_cache.flatten().tolist(), ) self.model.add_initializer(cos_cache_tensor, self.this_graph_name) sin_cache_tensor = helper.make_tensor( name=sin_cache_name, data_type=TensorProto.FLOAT, dims=list(sin_cache.shape), vals=sin_cache.flatten().tolist(), ) self.model.add_initializer(sin_cache_tensor, self.this_graph_name) self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) rotary_emb_node = helper.make_node( self.base_name, inputs=[root_input, position_ids, cos_cache_name, sin_cache_name], outputs=[output], name=rotary_emb_node_name, interleaved=0, ) rotary_emb_node.domain = "com.microsoft" return rotary_emb_node def fuse(self, node, input_name_to_nodes, output_name_to_node): # Node is either RotaryEmbedding function or Add if self.base_name not in node.op_type and node.op_type != "Add": return # Check if node is "RotaryEmbedding nn.Module" exported as a function # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export) rotary_emb_node = None if node.op_type != "Add": # Verify that function has the correct inputs if len(node.input) not in {4, 5} or node.input[1] not in { "pos", "pos_id", "position_id", "pos_ids", "position_ids", }: logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function") return rotary_emb_node = self.create_rotary_embeddings_from_function(node) if rotary_emb_node is None: logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") return # Remove RotaryEmbedding function self.nodes_to_remove.append(node) # Remove RotaryEmbedding function's shape inference stored in value_info # The new shape will be calculated during symbolic shape inference old_shape_infer = list( filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info) ) assert len(old_shape_infer) == 1 self.model.model.graph.value_info.remove(old_shape_infer[0]) else: # Rotary embeddings are defined using the below functions: # # def rotate_half(x): # """Rotates half the hidden dims of the input.""" # x1 = x[..., : x.shape[-1] // 2] # x2 = x[..., x.shape[-1] // 2 :] # return torch.cat((-x2, x1), dim=-1) # # def apply_rope(x, cos, sin, position_ids): # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] # x_embed = (x * cos) + (rotate_half(x) * sin) # return x_embed # Check paths for rotate_half(x) rotate_half_x2_path_1_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Transpose"], [1, 0, 0, 0, 0], ) rotate_half_x2_path_1_2 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Slice"], [1, 0, 0, 0, 0], ) rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2 rotate_half_x2_path_2_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], [1, 0, 0, 0, 1, 0, 0, 0, 0], ) rotate_half_x2_path_2_2 = self.model.match_parent_path( node, ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"], [1, 0, 0, 0, 1, 0, 0, 0, 0], ) rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2 if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None: logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half") return rotate_half_x1_path_1_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Transpose"], [1, 0, 1, 0], ) rotate_half_x1_path_1_2 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Slice"], [1, 0, 1, 0], ) rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2 rotate_half_x1_path_2_1 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], [1, 0, 1, 2, 0, 0, 0, 0], ) rotate_half_x1_path_2_2 = self.model.match_parent_path( node, ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"], [1, 0, 1, 2, 0, 0, 0, 0], ) rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2 if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None: logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half") return if ( rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name ): logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half") return # Check path for x x_path_1 = self.model.match_parent_path( node, ["Mul", "Transpose"], [0, 0], ) x_path_2 = self.model.match_parent_path( node, ["Mul", "Slice"], [0, 0], ) x_path = x_path_1 or x_path_2 if x_path is None: logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half") return # Check path for sin sin_path, sin_cache, position_ids = None, "", "" sin_path_1 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], [1, 1, 0, 0, 0, 0, 2, 0, 0], ) sin_path_2 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], [1, 1, 0, 0, 0, 0, 2, 0], ) sin_path_3 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], [1, 1, 0, 0, 2, 0, 0], ) sin_path_4 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], [1, 1, 0, 0, 2, 0], ) if sin_path_1 is not None: sin_path = sin_path_1 sin_cache = sin_path[-4].input[0] elif sin_path_2 is not None: sin_path = sin_path_2 sin_cache = sin_path[-3].input[0] elif sin_path_3 is not None: sin_path = sin_path_3 sin_cache = sin_path[-4].input[0] position_ids = sin_path[2].input[1] elif sin_path_4 is not None: sin_path = sin_path_4 sin_cache = sin_path[-3].input[0] position_ids = sin_path[2].input[1] else: logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") return # Check path for cos cos_path, cos_cache = None, "" cos_path_1 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], [0, 1, 0, 0, 0, 0, 2, 0, 0], ) cos_path_2 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], [0, 1, 0, 0, 0, 0, 2, 0], ) cos_path_3 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], [0, 1, 0, 0, 2, 0, 0], ) cos_path_4 = self.model.match_parent_path( node, ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], [0, 1, 0, 0, 2, 0], ) if cos_path_1 is not None: cos_path = cos_path_1 cos_cache = cos_path[-4].input[0] elif cos_path_2 is not None: cos_path = cos_path_2 cos_cache = cos_path[-3].input[0] elif cos_path_3 is not None: cos_path = cos_path_3 cos_cache = cos_path[-4].input[0] position_ids = cos_path[2].input[1] elif cos_path_4 is not None: cos_path = cos_path_4 cos_cache = cos_path[-3].input[0] position_ids = cos_path[2].input[1] else: logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") return # Check path for position ids if position_ids == "": position_ids_from_sin_path = self.model.match_parent_path( sin_path[2], ["Reshape"], [1], ) position_ids_from_cos_path = self.model.match_parent_path( cos_path[2], ["Reshape"], [1], ) if ( position_ids_from_sin_path is None or position_ids_from_cos_path is None or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name ): logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope") return position_ids = position_ids_from_cos_path[0].input[0] else: position_ids_from_sin_path = [] position_ids_from_cos_path = [] past_seq_len_path, curr_seq_len_path = None, None if (sin_path == sin_path_1 and cos_path == cos_path_1) or ( sin_path == sin_path_3 and cos_path == cos_path_3 ): if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name: logger.debug( "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache" ) return elif (sin_path == sin_path_2 and cos_path == cos_path_2) or ( sin_path == sin_path_4 and cos_path == cos_path_4 ): if sin_path[-1].name != cos_path[-1].name: logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache") return # Match past sequence length path: past_key --> Shape --> Gather --> Add past_seq_len_path = self.model.match_parent_path( sin_path[-1], ["Gather", "Shape"], [1, 0], ) # Match current sequence length path: transpose_k --> Shape --> Gather --> Add curr_seq_len_path = self.model.match_parent_path( sin_path[-1], ["Gather", "Shape", "Transpose"], [0, 0, 0], ) if ( past_seq_len_path is None or curr_seq_len_path is None or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None or curr_seq_len_path[-1].op_type != "Transpose" ): logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths") return else: logger.debug("fuse_rotary_embeddings: failed to match common cache paths") rotary_emb_node = self.create_rotary_embeddings_from_nodes( rotate_half_x1_path_1[-1].output[0], position_ids, cos_cache, sin_cache, node.output[0], ) if rotary_emb_node is None: logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") return # Remove rotary embedding nodes self.add_nodes_to_remove([node]) self.add_nodes_to_remove(rotate_half_x1_path_1[:-1]) self.add_nodes_to_remove(rotate_half_x1_path_2[:-1]) self.add_nodes_to_remove(rotate_half_x2_path_1[:-1]) self.add_nodes_to_remove(rotate_half_x2_path_2[:-1]) self.add_nodes_to_remove(x_path[:-1]) self.add_nodes_to_remove(sin_path) self.add_nodes_to_remove(cos_path) self.add_nodes_to_remove(position_ids_from_sin_path[:-1]) self.add_nodes_to_remove(position_ids_from_cos_path[:-1]) if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1: # In merged HF model, output of Gather in past_seq_len_path is used twice # for past_key_values.0.key and once for other past_key_values self.add_nodes_to_remove(past_seq_len_path) if curr_seq_len_path is not None: self.add_nodes_to_remove(curr_seq_len_path[:-1]) self.increase_counter(self.base_name) self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name self.nodes_to_add.append(rotary_emb_node) self.prune_graph = True
Memory