# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionBartAttention(FusionAttention):
    """
    Fuse Bart Attention subgraph into one Attention node.

    Besides the original BART export layout, this fusion also matches the layout produced by the
    OpenAI implementation of Whisper ("openai" paths below). Depending on the matched layout, the
    subgraph is replaced by either an ``Attention`` node or a ``MultiHeadAttention`` node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """Forward all arguments unchanged to ``FusionAttention.__init__``."""
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def check_runtime_shape_path(
        self,
        reshape_qkv_2,
        reshape_qkv_1,
        reshape_q_2,
        reshape_k_2,
        reshape_v_2,
        root_input,
    ):
        """Check that the Reshape target shapes of a (non-OpenAI) attention block derive from ``root_input``.

        Expected structure:
          * ``reshape_qkv_2``'s shape (input 1) is a Concat whose inputs 0 and 1 are each
            ``Unsqueeze(Gather(Shape(root_input)))`` (presumably the batch-size and sequence-length
            dimensions -- TODO confirm);
          * ``reshape_qkv_1``'s shape Concat reuses those same two Gather nodes (inputs 0 and 2);
          * each of the Q/K/V Reshape shapes starts with ``Unsqueeze(Mul(...))`` whose Mul's first
            input is the first Gather's output.

        Returns:
            bool: True if every expected path matches, False otherwise.
        """
        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
        if concat_qkv_2_path is None:
            return False
        concat_qkv_2 = concat_qkv_2_path[0]

        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None:
            return False

        _, gather_1, shape_1 = reshape_qkv_2_path_1
        _, gather_2, shape_2 = reshape_qkv_2_path_2
        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
            return False

        reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0])
        reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0])
        if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None:
            return False
        if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name:
            return False

        # Q, K and V reshapes must all scale the first gathered dimension via a Mul.
        gather_1_out = gather_1.output[0]
        for reshape_node in (reshape_q_2, reshape_k_2, reshape_v_2):
            path = self.model.match_parent_path(reshape_node, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
            if path is None or path[-1].input[0] != gather_1_out:
                return False

        return True

    def check_runtime_shape_path_openai(
        self,
        reshape_qkv_2,
        matmul_qkv,
        add_qk,
        matmul_qk,
        add_q,
    ):
        """Check the runtime shape paths of an attention block in the OpenAI Whisper layout.

        Expected structure:
          * ``reshape_qkv_2``'s shape (input 1) is ``Concat(Slice(Gather(Shape(matmul_qkv))))``;
          * both inputs of ``matmul_qk`` are ``Mul`` nodes sharing the same scale (input 1), computed as
            ``Pow(Cast(Div(Gather(Shape(...)))))``, with at least one of the two Shapes reading
            ``add_q``'s output;
          * for decoder attention only (``add_qk`` is not None): ``add_qk``'s input 1 is a Slice. Both
            that Slice and its input-0 Slice take input 2 from Unsqueeze nodes fed by the same tensor.
            At least one of the two Shape nodes at the end of those paths reads ``add_q``'s output.

        Returns:
            bool: True if all expected paths match, False otherwise.
        """
        reshape_qkv_2_path = self.model.match_parent_path(
            reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0]
        )
        if reshape_qkv_2_path is None or reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]:
            return False

        matmul_qk_path_1 = self.model.match_parent_path(
            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0]
        )
        matmul_qk_path_2 = self.model.match_parent_path(
            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0]
        )
        if matmul_qk_path_1 is None or matmul_qk_path_2 is None:
            return False

        mul_1 = matmul_qk_path_1[0]
        mul_2 = matmul_qk_path_2[0]
        if mul_1.input[1] != mul_2.input[1]:
            return False
        if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]:
            return False

        # For decoder attentions only
        if add_qk is not None:
            add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1])
            if add_qk_path is None:
                return False
            slice_q_path_1 = self.model.match_parent_path(
                add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0]
            )
            slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
            # Both paths are unpacked and compared below, so both must have matched.
            if slice_q_path_1 is None or slice_q_path_2 is None:
                return False
            _, unsqueeze_1, _, _ = slice_q_path_1
            unsqueeze_2, _, _ = slice_q_path_2
            if unsqueeze_1.input[0] != unsqueeze_2.input[0]:
                return False
            if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]:
                return False

        return True

    def _get_identity_output(self, past_name):
        """Return the output name of the single Identity node consuming ``past_name``.

        Returns "" when ``past_name`` is consumed by zero or by more than one Identity node.
        """
        identity_nodes = [
            node for node in self.model.input_name_to_nodes()[past_name] if node.op_type == "Identity"
        ]
        return identity_nodes[0].output[0] if len(identity_nodes) == 1 else ""

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the attention subgraph feeding ``normalize_node`` into one node.

        Several export layouts (BART-style and OpenAI Whisper-style, with or without past key/value)
        are matched. On success, a new ``Attention`` or ``MultiHeadAttention`` node is appended to
        ``self.nodes_to_add``, and the replaced nodes are appended to ``self.nodes_to_remove``.
        When any required path fails to match, the method returns without changing anything.

        Args:
            normalize_node: the normalization node (e.g. SkipLayerNormalization) that consumes the
                attention block; its input 1 is expected to come from the output-projection ``Add``.
            input_name_to_nodes: mapping from tensor name to the list of nodes consuming it.
            output_name_to_node: mapping from tensor name to the node producing it.
        """
        # Track if fusion is occurring for OpenAI implementation of Whisper
        model_impl_openai = False
        # Only present in the non-OpenAI layout; used only by check_runtime_shape_path.
        reshape_qkv_1 = None

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0, 0, 0],
        )
        qkv_nodes_openai = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            _, _, reshape_qkv_2, transpose_qkv, reshape_qkv_1, matmul_qkv = qkv_nodes
        elif qkv_nodes_openai is not None:
            qkv_nodes = qkv_nodes_openai
            _, _, reshape_qkv_2, transpose_qkv, matmul_qkv = qkv_nodes
            # Set model implementation to openai
            model_impl_openai = True
        else:
            return

        # The root input is the normalize_node input produced by a node other than the attention output Add.
        qkv_output = qkv_nodes[0].output[0]
        other_inputs = [name for name in normalize_node.input if name in output_name_to_node and name != qkv_output]
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
        # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
        # children nodes for each of its output names.
        #
        #   root_input
        #   +---------------------------------------------------+
        #   |                                                   |
        #   |                                                   |
        #   SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
        skip_layernorm = output_name_to_node[root_input]
        # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
        # child is the LayerNormalization node.
        if skip_layernorm.op_type == "Add":
            skip_layernorm = self.model.get_children(skip_layernorm)[0]
        for output in skip_layernorm.output:
            if not output:
                continue
            # Outputs with no consumers are absent from the map; treat them as having no children.
            children = input_name_to_nodes.get(output, [])
            if any(child.op_type == "MatMul" for child in children):
                root_input = output
                break

        graph_input_names = {node.name for node in self.model.graph().input}
        graph_output_names = {node.name for node in self.model.graph().output}

        # ---------------------------------------------------------------- V path
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, 0, None],
        )
        v_nodes_openai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, None],
        )
        v_nodes_with_past_self_attn = self.model.match_parent_path(
            # Decoder attention with past value concatenated before MatMul
            matmul_qkv,
            ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, None],
        )
        v_nodes_with_past_cross_attn = self.model.match_parent_path(
            # Decoder attention with past value directly used in MatMul
            matmul_qkv,
            ["Reshape"],
            [1],
        )
        v_nodes_with_past_cross_attn_openai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0],
        )
        past_v, present_v = "", ""
        # matmul_v stays None for the cross-attention-with-past layouts (V comes straight from past_v).
        reshape_v_2, add_v, matmul_v = None, None, None
        if v_nodes is not None:
            (reshape_v_2, transpose_v, _, add_v, matmul_v) = v_nodes
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
            present_v = transpose_v.output[0]
        elif v_nodes_openai is not None:
            v_nodes = v_nodes_openai
            (_, reshape_v_1, add_v, matmul_v) = v_nodes
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)

            # Find the child path to access the correct present_v values
            # Openai impl provides present/past v values in 3D format
            # whereas ort MultiHeadAttention expects v values in 4D, hence the
            # additional Reshape and Transpose nodes are added
            # For encoder attention types
            # Add -> Reshape -> Transpose -> Present_V
            reshape_path = self.model.match_child_path(
                add_v,
                ["Reshape", "Transpose"],
                exclude=[reshape_v_1],
            )
            # For decoder attention types
            # add_v_node   Reshape <- Transpose <-Past_V
            #     \           /
            #      \         /
            #       -> Concat <-
            #            |
            #            |--> Reshape -> Transpose -> Present_V
            concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"])
            if reshape_path is not None:
                (_, transpose_add_v) = reshape_path
                if transpose_add_v.output[0] in graph_output_names:
                    present_v = transpose_add_v.output[0]
            if concat_path is not None:
                (concat_v, _, transpose_concat_v) = concat_path
                if transpose_concat_v.output[0] in graph_output_names:
                    present_v = transpose_concat_v.output[0]
                concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0])
                if concat_nodes is not None:
                    _, transpose_concat_v_in = concat_nodes
                    past_v = transpose_concat_v_in.input[0]
        elif v_nodes_with_past_self_attn is not None:
            (reshape_v_2, concat_v, _, _, add_v, matmul_v) = v_nodes_with_past_self_attn
            v_nodes = v_nodes_with_past_self_attn
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif (
            v_nodes_with_past_cross_attn is not None
            and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names
        ):
            v_nodes = v_nodes_with_past_cross_attn
            past_v = v_nodes[-1].input[0]
            present_v = v_nodes[-1].output[0]
            if present_v not in graph_output_names:
                present_v = self._get_identity_output(past_v)
        elif (
            v_nodes_with_past_cross_attn_openai is not None
            and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names
        ):
            v_nodes = v_nodes_with_past_cross_attn_openai
            past_v = v_nodes[-1].input[0]
            present_v = v_nodes[-1].output[0]
            if present_v not in graph_output_names:
                present_v = self._get_identity_output(past_v)
        else:
            logger.debug("fuse_attention: failed to match v path")
            return
        past_v = past_v if past_v in graph_input_names else ""
        present_v = present_v if present_v in graph_output_names else ""

        # ---------------------------------------------------------------- QK path
        qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        qk_nodes_2 = self.model.match_parent_path(
            matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0]
        )
        qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        add_qk = None
        if qk_nodes_1 is not None:
            _, matmul_qk = qk_nodes_1
            qk_nodes = qk_nodes_1
        elif qk_nodes_2 is not None:
            _, _, add_qk, _, matmul_qk = qk_nodes_2
            qk_nodes = qk_nodes_2
        elif qk_nodes_2_openai is not None:
            _, add_qk, matmul_qk = qk_nodes_2_openai
            qk_nodes = qk_nodes_2_openai
        else:
            return

        # ---------------------------------------------------------------- Q path
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        q_nodes_openai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        reshape_q_2 = None
        if q_nodes is not None:
            reshape_q_2, _, reshape_q_1, _, add_q, matmul_q = q_nodes
        elif q_nodes_openai is not None:
            q_nodes = q_nodes_openai
            _, _, reshape_q_1, add_q, matmul_q = q_nodes
        else:
            return

        # ---------------------------------------------------------------- K path
        k_nodes_with_bias = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, 0, 0, 1],
        )
        k_nodes_with_bias_openai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0],
        )
        k_nodes_no_bias = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path(
            # Decoder attention with past key concatenated before MatMul
            matmul_qk,
            ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path(
            # Decoder attention with past key directly used in MatMul
            matmul_qk,
            ["Transpose", "Reshape"],
            [1, 0],
        )
        k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path(
            # Decoder attention with past key directly used in MatMul
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0, 0],
        )
        past_k, present_k = "", ""
        # matmul_k stays None for the cross-attention-with-past layouts (K comes straight from past_k).
        reshape_k_2, reshape_k_1, matmul_k, add_k = None, None, None, None
        if k_nodes_with_bias is not None:
            _, reshape_k_2, _, reshape_k_1, add_k, matmul_k = k_nodes_with_bias
            k_nodes = k_nodes_with_bias
        elif k_nodes_with_bias_openai is not None:
            _, _, reshape_k_1, matmul_k = k_nodes_with_bias_openai
            k_nodes = k_nodes_with_bias_openai
            present_k = matmul_k.output[0]

            # Find the child path to access the correct present_k values
            # Openai impl provides present/past k values in 3D format
            # whereas ort MultiHeadAttention expects k values in 4D, hence the
            # additional Reshape and Transpose nodes are added
            # For encoder attention types
            # Matmul -> Reshape -> Transpose -> Present_K
            reshape_path = self.model.match_child_path(
                matmul_k,
                ["Reshape", "Transpose"],
                exclude=[reshape_k_1],
            )
            # For decoder attention types
            # matmul_k_node  Reshape <- Transpose <- Past_K
            #     \           /
            #      \         /
            #       -> Concat <-
            #            |
            #            |--> Reshape -> Transpose -> Present_K
            concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"])
            if reshape_path is not None:
                (_, transpose_matmul_k) = reshape_path
                if transpose_matmul_k.output[0] in graph_output_names:
                    present_k = transpose_matmul_k.output[0]
            if concat_path is not None:
                (concat_k, _, transpose_concat_k) = concat_path
                if transpose_concat_k.output[0] in graph_output_names:
                    present_k = transpose_concat_k.output[0]
                concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0])
                if concat_nodes is not None:
                    _, transpose_concat_k_in = concat_nodes
                    past_k = transpose_concat_k_in.input[0]
        elif k_nodes_no_bias is not None:
            _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias
            k_nodes = k_nodes_no_bias
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
            present_k = transpose_k_1.output[0]
        elif k_nodes_no_bias_with_past_self_attn is not None:
            _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn
            k_nodes = k_nodes_no_bias_with_past_self_attn
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif (
            k_nodes_no_bias_with_past_cross_attn is not None
            and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names
        ):
            k_nodes = k_nodes_no_bias_with_past_cross_attn
            past_k = k_nodes[-1].input[0]
            present_k = k_nodes[-1].output[0]
            if present_k not in graph_output_names:
                present_k = self._get_identity_output(past_k)
        elif (
            k_nodes_no_bias_with_past_cross_attn_openai is not None
            and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names
        ):
            k_nodes = k_nodes_no_bias_with_past_cross_attn_openai
            past_k = k_nodes[-1].input[0]
            present_k = k_nodes[-1].output[0]
            if present_k not in graph_output_names:
                present_k = self._get_identity_output(past_k)
        else:
            return
        past_k = past_k if past_k in graph_input_names else ""
        present_k = present_k if present_k in graph_output_names else ""

        if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn):
            # These K layouts have no bias Add of their own: synthesize "Add(empty_bias, matmul_k)" with a
            # zero bias sized like the V bias, so the attention node receives a K bias input.
            # The V path matches its MatMul at either Add input, so the bias initializer may be at either index.
            bias_initializer = None
            if add_v is not None:
                for name in add_v.input:
                    bias_initializer = self.model.get_initializer(name)
                    if bias_initializer is not None:
                        break
            if bias_initializer is None:
                logger.debug("fuse_attention: failed to find bias initializer of v path")
                return
            bias_dim = bias_initializer.dims[0]
            empty_bias_name = "empty_bias"
            empty_tensor = self.model.get_initializer(empty_bias_name)
            if empty_tensor is None:
                self.add_initializer(
                    empty_bias_name,
                    TensorProto.FLOAT,
                    dims=[bias_dim],
                    vals=np.zeros(bias_dim, dtype=np.float32),
                )

            add_name = self.model.create_node_name("Add")
            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)

        # Without past key/value, the reshape shapes must be computed at runtime from the expected inputs.
        if not past_k:
            if model_impl_openai:
                shape_ok = self.check_runtime_shape_path_openai(
                    reshape_qkv_2,
                    matmul_qkv,
                    add_qk,
                    matmul_qk,
                    add_q,
                )
            else:
                shape_ok = self.check_runtime_shape_path(
                    reshape_qkv_2,
                    reshape_qkv_1,
                    reshape_q_2,
                    reshape_k_2,
                    reshape_v_2,
                    root_input,
                )
            if not shape_ok:
                return

        # K/V projection MatMuls are absent (None) when K/V come directly from past inputs.
        three_root_inputs = past_k and past_v and matmul_k is None and matmul_v is None
        one_root_input = (
            not three_root_inputs
            and matmul_k is not None
            and matmul_v is not None
            and matmul_k.input[0] == root_input
            and matmul_q.input[0] == root_input
            and matmul_v.input[0] == root_input
        )
        two_root_inputs = (
            not three_root_inputs
            and matmul_k is not None
            and matmul_v is not None
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == matmul_v.input[0]
            and matmul_k.input[0] != matmul_q.input[0]
        )

        # There are 5 types of attention:
        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1
        # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2
        # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value
        # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1
        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
        encoder_attention = one_root_input and qk_nodes == qk_nodes_1
        decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai)
        decoder_attention_with_past = (
            (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v
        )
        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1

        # For decoder_attention, the attention mask needs to be included in the attention node
        mask_index = None
        if decoder_attention:
            mask_nodes_bart = self.model.match_parent_path(
                add_qk,
                ["Where"],
                [1],
            )
            mask_nodes_whisper = self.model.match_parent_path(
                add_qk,
                ["Expand", "Unsqueeze", "Unsqueeze", "Where"],
                [1, 0, 0, 0],
            )
            if mask_nodes_whisper is not None:
                mask_index = mask_nodes_whisper[0].output[-1]
            elif mask_nodes_bart is not None:
                mask_index = mask_nodes_bart[0].output[-1]

        if not (
            encoder_attention
            or decoder_attention
            or decoder_attention_with_past
            or decoder_cross_attention
            or decoder_cross_attention_with_past
        ):
            return

        attention_last_node = reshape_qkv_2
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)
        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
            return

        new_node = None
        uses_multihead_attention = (
            decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past
        )
        if uses_multihead_attention:
            # Note: Decoder attention with past key and past value is fused as multihead attention
            # rather than attention because multihead attention supports separate past key and past
            # value whereas attention supports concatenated past key and past value.
            # K/V are passed as projection nodes when computed in-graph, else as past tensor names.
            kv_from_projection = decoder_cross_attention or decoder_attention_with_past
            new_node = (
                self.create_multihead_attention_node(
                    matmul_q,
                    matmul_k if kv_from_projection else past_k,
                    matmul_v if kv_from_projection else past_v,
                    add_q,
                    add_k if kv_from_projection else None,
                    add_v if kv_from_projection else None,
                    num_heads,
                    hidden_size,
                    attention_last_node.output[0],
                    past_k=past_k if decoder_attention_with_past else "",
                    past_v=past_v if decoder_attention_with_past else "",
                    present_k=present_k,
                    present_v=present_v,
                    packed_qkv=decoder_attention_with_past,
                )
                if self.use_multi_head_attention
                else None
            )
        else:
            # Temporarily set multihead attention flag to false; restore it even if node creation raises.
            use_multi_head_attention_ground_truth = self.use_multi_head_attention
            self.use_multi_head_attention = False
            try:
                new_node = self.create_attention_node(
                    None,
                    matmul_q,
                    matmul_k,
                    matmul_v,
                    add_q,
                    add_k,
                    add_v,
                    num_heads,
                    hidden_size,
                    root_input,
                    attention_last_node.output[0],
                    add_qk_str=mask_index if decoder_attention else None,
                    past_k=past_k,
                    past_v=past_v,
                    present_k=present_k,
                    present_v=present_v,
                )
            finally:
                self.use_multi_head_attention = use_multi_head_attention_ground_truth

        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        if uses_multihead_attention:
            # With bias disabled for cross attention, the projection bias Add nodes are kept as well.
            keep_bias_add = self.disable_multi_head_attention_bias and (
                decoder_cross_attention or decoder_cross_attention_with_past
            )
            for nodes in (q_nodes, k_nodes, v_nodes):
                if nodes[-1].op_type == "MatMul":
                    nodes.pop()
                if keep_bias_add and nodes[-1].op_type == "Add":
                    nodes.pop()

        self.nodes_to_remove.extend(q_nodes)
        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
Memory