# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional, Union
from fusion_attention import FusionAttention
from fusion_base import Fusion
from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class FusionRotaryAttention(FusionAttention):
"""
Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
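Minimal usage sketch (a hedged example, assuming the standard transformers optimizer flow; the model path
is hypothetical and the sizes shown are those of LLaMA-2 7B):
    import onnx
    model = OnnxModel(onnx.load("llama2_decoder.onnx"))
    fusion = FusionRotaryAttention(model, hidden_size=4096, num_heads=32)
    fusion.apply()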
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
):
super().__init__(
model,
hidden_size,
num_heads,
use_multi_head_attention=True,
search_op_types=[
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"LayerNormalization",
"SkipLayerNormalization",
"Add",
],
)
def create_mha_node(
self,
input: str,
output: str,
q_rotary: NodeProto,
k_rotary: NodeProto,
v_matmul: NodeProto,
attn_mask: str = "",
add_qk: str = "",
past_k: str = "",
past_v: str = "",
present_k: str = "",
present_v: str = "",
scale: Optional[float] = None,
) -> Union[NodeProto, None]:
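"""Create a com.microsoft MultiHeadAttention node from the matched q/k/v paths.
The query and key inputs are the outputs of the q/k RotaryEmbedding nodes, the value input is the output of
the v MatMul (an Add node for the Phi-2 pattern), and mask, attention bias, and past/present KV tensors are
optional. Returns the new node, or None if hidden_size is not a multiple of num_heads.
"""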
assert self.num_heads > 0
if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0:
logger.debug(
f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
)
return None
mha_node_name = self.model.create_node_name("MultiHeadAttention")
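# com.microsoft MultiHeadAttention inputs are ordered as:
# query, key, value, bias, key_padding_mask, attention_bias, past_key, past_value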
mha_inputs = [
q_rotary.output[0],
k_rotary.output[0],
v_matmul.output[0],
"", # bias
attn_mask, # key_padding_mask
add_qk, # attention_bias
past_k,
past_v,
]
mha_outputs = [output]
if present_k and present_v:
mha_outputs.extend([present_k, present_v])
mha_node = helper.make_node(
"MultiHeadAttention",
inputs=mha_inputs,
outputs=mha_outputs,
name=mha_node_name,
)
mha_node.domain = "com.microsoft"
mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
if scale is not None:
mha_node.attribute.extend([helper.make_attribute("scale", scale)])
if self.mask_filter_value is not None:
mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
self.increase_counter("MultiHeadAttention")
return mha_node
def check_runtime_shape_paths_for_function(
self,
reshape_qkv_2, # Reshape after Transpose
reshape_qkv_1, # Reshape before Transpose
reshape_q_2, # Reshape after RotaryEmbedding
reshape_k_2, # Reshape after RotaryEmbedding
reshape_v_2, # Reshape after Transpose
reshape_v_1, # Reshape before Transpose
add_qk, # Add before Softmax
root_input, # Root input to attention subgraph
):
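"""Verify the runtime shape subgraphs (Shape --> Gather --> Unsqueeze --> Concat) feeding the Reshape nodes
for the function-based (Microsoft) export. All paths must trace back to the same Gather nodes over the
attention subgraph's root input so that these shape computations can be safely removed during fusion.
"""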
# Check #1: check paths for qkv nodes
concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
if concat_qkv_2_path is None or concat_qkv_1_path is None:
return False
concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]
reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
if (
reshape_qkv_2_path_1 is None
or reshape_qkv_2_path_2 is None
or reshape_qkv_1_path_1 is None
or reshape_qkv_1_path_2 is None
):
return False
_, gather_1, shape_1 = reshape_qkv_2_path_1
_, gather_2, shape_2 = reshape_qkv_2_path_2
# Check root_input --> Shape --> Gather connection
if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
return False
# Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
return False
# Check #2: check paths for v nodes
concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
if concat_v_2_path is None or concat_v_1_path is None:
return False
concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]
reshape_v_2_path_1 = self.model.match_parent_path(
concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
)
reshape_v_2_path_2 = self.model.match_parent_path(
concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
)
reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if (
reshape_v_2_path_1 is None
or reshape_v_2_path_2 is None
or reshape_v_1_path_1 is None
or reshape_v_1_path_2 is None
):
return False
# Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
# Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
# Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
if (
reshape_v_2_path_1[2].name != gather_1.name
or reshape_v_2_path_2[2].name != gather_2.name
or reshape_v_1_path_1[1].name != gather_1.name
or reshape_v_1_path_2[1].name != gather_2.name
):
return False
# Check #3: check paths for k nodes
concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
if concat_k_2_path is None:
return False
concat_k_2 = concat_k_2_path[0]
reshape_k_2_path_1 = self.model.match_parent_path(
concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
)
reshape_k_2_path_2 = self.model.match_parent_path(
concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
)
if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
return False
# Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
# Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
return False
# Check #4: check paths for q nodes
concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
if concat_q_2_path is None:
return False
concat_q_2 = concat_q_2_path[0]
reshape_q_2_path_1 = self.model.match_parent_path(
concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
)
reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
return False
# Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
# Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
return False
# Check #5: check Mul nodes are the same for q, k, v
mul_q = reshape_q_2_path_1[1]
mul_k = reshape_k_2_path_1[1]
mul_v = reshape_v_2_path_1[1]
gather_1_out = gather_1.output[0]
if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
return False
# Check #6: check paths for attention mask nodes
attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
if attn_mask_path_1 is not None:
_, slice_qk_2, slice_qk_1 = attn_mask_path_1
elif attn_mask_path_2 is not None:
_, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
else:
return False
# Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
return False
slice_qk_2_path = self.model.match_parent_path(
slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
)
slice_qk_1_path_1 = self.model.match_parent_path(
slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
)
slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
return False
# Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
# Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
return False
# Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
# Check if first input to Add and Unsqueeze #1 is position ids
if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
return False
return True
def check_runtime_shape_paths_for_nodes(
self,
reshape_qkv, # Final reshape before o_proj MatMul
reshape_q, # Reshape before q_proj MatMul
reshape_k, # Reshape before k_proj MatMul
reshape_v, # Reshape before v_proj MatMul
root_input, # Root input to attention subgraph
):
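"""Verify the runtime shape subgraphs feeding the Reshape nodes for the node-based (Hugging Face) export.
Each Reshape's shape Concat must trace back to the same Gather/Shape nodes over the root input.
"""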
# Check #1: check paths for qkv nodes
concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1])
if concat_qkv_path is None:
return False
concat_qkv = concat_qkv_path[0]
reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None:
return False
_, gather_1, shape_1 = reshape_qkv_path_1
_, gather_2, shape_2 = reshape_qkv_path_2
# Check root_input --> Shape --> Gather connection
if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
return False
# Check #2: check paths for v nodes
concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1])
if concat_v_path is None:
return False
concat_v = concat_v_path[0]
reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_v_path_1 is None or reshape_v_path_2 is None:
return False
# Check Gather --> Unsqueeze --> Concat --> Reshape connection
if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name:
return False
# Check #3: check paths for k nodes
concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1])
if concat_k_path is None:
return False
concat_k = concat_k_path[0]
reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_k_path_1 is None or reshape_k_path_2 is None:
return False
# Check Gather --> Unsqueeze --> Concat --> Reshape connection
if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name:
return False
# Check #4: check paths for q nodes
concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1])
if concat_q_path is None:
return False
concat_q = concat_q_path[0]
reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
if reshape_q_path_1 is None or reshape_q_path_2 is None:
return False
# Check Gather --> Unsqueeze --> Concat --> Reshape connection
if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name:
return False
return True
def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
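"""Fuse a rotary attention subgraph rooted at normalize_node into one MultiHeadAttention node.
Matching proceeds in order: the qkv output path, the v path, the qk (Softmax) path, the attention mask,
the k path, and the q path; the runtime shape paths are then verified before the fused node is created.
"""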
if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
return
# qkv_nodes_1 is for LLaMA-2 Microsoft
# qkv_nodes_2 is for LLaMA-2 Hugging Face
# qkv_nodes_3 is for the distributed LLaMA-2 Hugging Face model
qkv_nodes = None
qkv_nodes_1 = self.model.match_parent_path(
normalize_node,
["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0, 0],
)
qkv_nodes_2 = self.model.match_parent_path(
normalize_node,
["MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0],
)
qkv_nodes_3 = self.model.match_parent_path(
normalize_node,
["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 0, 0, 0, 0],
)
if qkv_nodes_1 is not None:
_, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
qkv_nodes = qkv_nodes_1
elif qkv_nodes_2 is not None:
_, reshape_qkv, _, matmul_qkv = qkv_nodes_2
qkv_nodes = qkv_nodes_2
elif qkv_nodes_3 is not None:
_, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
qkv_nodes = qkv_nodes_3
else:
logger.debug("fuse_rotary_attention: failed to match qkv nodes")
return
# v_nodes_1 is for LLaMA-2 Microsoft
# v_nodes_3 is for LLaMA-2 Hugging Face
# v_nodes_4 is for LLaMA-2 70B model
# v_nodes_5 is for Phi-2 DirectML
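# past_v/present_v hold the value-path KV cache tensor names; past_seq_len is the past sequence length
# input shared with the key path (asserted below when matching k_nodes_1)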
past_v, present_v, past_seq_len = "", "", ""
v_nodes = None
add_v = None
v_nodes_1 = self.model.match_parent_path(
matmul_qkv,
["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 1, 0, 0],
)
v_nodes_2 = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0],
)
v_nodes_3 = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
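# v_nodes_4 covers nine variants of the value path seen in the 70B model, where Expand/Where/Concat
# subgraphs broadcast the KV heads; all nine candidate paths must match (len(v_nodes_4) == 9 below)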
_, v_nodes_4, _ = self.model.match_parent_paths_all(
matmul_qkv,
[
(
["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Equal",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Equal",
"Mul",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Expand",
"Where",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0, 0, 0, 1, 0, 0],
),
(
[
"Reshape",
"Concat",
"Unsqueeze",
"Mul",
"Gather",
"Shape",
"Concat",
"Transpose",
"Reshape",
"MatMul",
],
[1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 2, 0, 0, 0, 1, 0, 0],
),
(
["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 3, 0, 0, 0, 1, 0, 0],
),
],
output_name_to_node=None,
)
v_nodes_5 = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "Add", "MatMul"],
[1, 1, 0, 0, 1],
)
if v_nodes_1 is not None:
reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
v_nodes = v_nodes_1
concat_v_path = self.model.match_parent_path(
concat_v,
["Slice", "Unsqueeze"],
[0, 2],
)
if concat_v_path is None:
logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
return
past_v = concat_v_path[0].input[0]
past_seq_len = concat_v_path[-1].input[0]
present_v = concat_v.output[0]
elif v_nodes_2 is not None:
concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
v_nodes = v_nodes_2
past_v = concat_v.input[0]
present_v = concat_v.output[0]
elif v_nodes_3 is not None:
transpose_v, reshape_v, matmul_v = v_nodes_3
v_nodes = v_nodes_3
present_v = transpose_v.output[0]
elif v_nodes_4 is not None and len(v_nodes_4) == 9:
concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
v_nodes = v_nodes_4
past_v = concat_v.input[0]
present_v = concat_v.output[0]
elif v_nodes_5 is not None:
concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
matmul_v = add_v
v_nodes = v_nodes_5
past_v = concat_v.input[0]
present_v = concat_v.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match v path")
return
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "Div", "MatMul"],
[0, 0, 0, 0],
)
add_qk, matmul_qk = None, None
if qk_nodes is not None:
_, add_qk, _, matmul_qk = qk_nodes
else:
logger.debug("fuse_rotary_attention: failed to match qk nodes")
return
# attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
# attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
# attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
# attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
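# attn_mask holds a 3D mask passed to MHA as key_padding_mask; add_qk_str holds a 4D additive mask
# passed as attention_bias (reshaped to (B,N,S,T) where needed)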
attn_mask, add_qk_str = "", ""
attn_mask_nodes_1 = self.model.match_parent_path(
add_qk,
["Concat", "Slice", "Slice"],
[1, 0, 0],
)
attn_mask_nodes_2 = self.model.match_parent_path(
add_qk,
["Cast", "Concat", "Slice", "Slice"],
[1, 0, 0, 0],
)
attn_mask_nodes_3 = self.model.match_parent_path(
add_qk,
["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_4 = self.model.match_parent_path(
add_qk,
["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 2, 1, 0, 0, 0],
)
attn_mask_nodes_5 = self.model.match_parent_path(
add_qk,
["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_6 = self.model.match_parent_path(
add_qk,
["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 2, 1, 0, 0, 0],
)
attn_mask_nodes_7 = self.model.match_parent_path(
add_qk,
["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if attn_mask_nodes_1 is not None:
_, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
attn_mask = slice_mask_1.output[0]
elif attn_mask_nodes_2 is not None:
_, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
attn_mask = slice_mask_1.output[0]
elif attn_mask_nodes_3 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
elif attn_mask_nodes_4 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
elif attn_mask_nodes_5 is not None:
# The mask has already been reshaped to (B,N,S,T)
add_qk_str = attn_mask_nodes_5[0].output[0]
elif attn_mask_nodes_6 is not None:
# The mask has already been reshaped to (B,N,S,T)
add_qk_str = attn_mask_nodes_6[0].output[0]
elif attn_mask_nodes_7 is not None:
# Reshape from (B,1,S,T) to (B,N,S,T)
add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
else:
logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
return
# k_nodes_1 is for LLaMA-2 Microsoft
# k_nodes_2 is for LLaMA-2 Hugging Face
# k_nodes_4 is for LLaMA-2 70B Hugging Face
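# k_nodes_5 is for Phi-2 DirectML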
past_k, present_k = "", ""
k_nodes = None
slice_k = None
concat_k_half = None
k_nodes_1 = self.model.match_parent_path(
matmul_qk,
["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
[1, 0, 0, 1, 0, 0],
)
k_nodes_2 = self.model.match_parent_path(
matmul_qk,
["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
[1, 0, 0, 0, 0],
)
k_nodes_3 = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
[1, 0, 1, 0, 0, 0],
)
_, k_nodes_4, _ = self.model.match_parent_paths_all(
matmul_qk,
[
(
[
"Transpose",
"Reshape",
"Expand",
"Unsqueeze",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Equal",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Equal",
"Mul",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"ConstantOfShape",
"Shape",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Expand",
"Where",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Mul",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
),
(
[
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"Concat",
"RotaryEmbedding",
"Transpose",
"Reshape",
"MatMul",
],
[1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
),
],
output_name_to_node=None,
)
k_nodes_5 = self.model.match_parent_path(
matmul_qk,
["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 1, 0, 0, 0, 0, 0, 1],
)
if k_nodes_1 is not None:
reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
k_nodes = k_nodes_1
concat_k_path = self.model.match_parent_path(
concat_k,
["Slice", "Unsqueeze"],
[0, 2],
)
if concat_k_path is None:
logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
return
past_k = concat_k_path[0].input[0]
shared_past_seq_len = concat_k_path[-1].input[0]
present_k = concat_k.output[0]
assert past_seq_len == shared_past_seq_len
elif k_nodes_2 is not None:
_, rotary_k, _, reshape_k, matmul_k = k_nodes_2
k_nodes = k_nodes_2
present_k = rotary_k.output[0]
elif k_nodes_3 is not None:
_, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
k_nodes = k_nodes_3
past_k = concat_k.input[0]
present_k = concat_k.output[0]
elif k_nodes_4 is not None and len(k_nodes_4) == 9:
reshape_k, matmul_k = k_nodes_4[0][-2:]
concat_k, rotary_k = k_nodes_4[0][-5:-3]
k_nodes = k_nodes_4
past_k = concat_k.input[0]
present_k = concat_k.output[0]
elif k_nodes_5 is not None:
_, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
k_nodes = k_nodes_5
past_k = concat_k.input[0]
present_k = concat_k.output[0]
else:
logger.debug("fuse_rotary_attention: failed to match k nodes")
return
# q_nodes_1 is for LLaMA-2 Microsoft
# q_nodes_2 is for LLaMA-2 Hugging Face
# q_nodes_3 is for Phi-2 DirectML
q_nodes = None
slice_q = None
concat_q_half = None
q_nodes_1 = self.model.match_parent_path(
matmul_qk,
["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
[0, 0, 0, 0],
)
q_nodes_2 = self.model.match_parent_path(
matmul_qk,
["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
[0, 0, 0, 0],
)
q_nodes_3 = self.model.match_parent_path(
matmul_qk,
["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 0, 0, 1],
)
if q_nodes_1 is not None:
reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
q_nodes = q_nodes_1
elif q_nodes_2 is not None:
rotary_q, _, reshape_q, matmul_q = q_nodes_2
q_nodes = q_nodes_2
elif q_nodes_3 is not None:
concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
q_nodes = q_nodes_3
else:
logger.debug("fuse_rotary_attention: failed to match q nodes")
return
if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
return
root_output = ""
if qkv_nodes == qkv_nodes_1:
if not self.check_runtime_shape_paths_for_function(
reshape_qkv_2,
reshape_qkv_1,
reshape_q_2,
reshape_k_2,
reshape_v_2,
reshape_v_1,
add_qk,
matmul_q.input[0],
):
logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
return
root_output = reshape_qkv_2.output[0]
elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
if not self.check_runtime_shape_paths_for_nodes(
reshape_qkv,
reshape_q,
reshape_k,
reshape_v,
matmul_q.input[0],
):
logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
return
root_output = reshape_qkv.output[0]
# Rename inputs of rotary_q/k so it connects with output of matmul_q/k
# Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
# After: MatMul --> RotaryEmbedding
rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]
# Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
if concat_q_half is None:
rotary_k.output[0] = rotary_k.name + "_output_0"
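# For the distributed model, drop the leading AllReduce node from qkv_nodes so that only the nodes
# after it are marked for removal below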
if qkv_nodes == qkv_nodes_3:
qkv_nodes = qkv_nodes[1:]
def create_hidden_size_concat_node(reshape_q):
"""Detect num_heads and hidden_size for ONNX model from phi-2
Args:
reshape_q (NodeProto): reshape node for q
Returns:
hidden_size_concat_node(NodeProto): Concat node to be used by reshape
"""
concat = self.model.match_parent(reshape_q, "Concat", 1)
if concat is None:
logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
return None
# The Reshape target shape is a tensor like [?, ?, num_heads, head_size]
num_heads_value = self.model.get_constant_value(concat.input[2])
head_size_value = self.model.get_constant_value(concat.input[3])
if num_heads_value is None or head_size_value is None:
logger.debug("fuse_rotary_attention: failed to get constant values of num_heads or head_size")
return None
hidden_size = num_heads_value[0] * head_size_value[0]
hidden_size_initializer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
if self.model.get_initializer(hidden_size_initializer) is None:
self.add_initializer(
name=hidden_size_initializer,
data_type=TensorProto.INT64,
dims=[1],
vals=[hidden_size],
raw=False,
)
hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")
hidden_size_concat_node = helper.make_node(
"Concat",
inputs=[
concat.input[0],
concat.input[1],
hidden_size_initializer,
],
outputs=[hidden_size_reshape_node_name + "_output_0"],
name=hidden_size_reshape_node_name,
)
hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])
return hidden_size_concat_node
# Add Transpose and Reshape nodes for the partial rotary embedding applied in Phi-2 before passing into MHA
if concat_q_half and concat_k_half:
# Transpose the key output of the rotary embedding
k_transpose_node_name = self.model.create_node_name("Transpose")
k_transpose_output_name = k_transpose_node_name + "_output_0"
k_transpose_node = helper.make_node(
"Transpose",
inputs=[concat_k_half.output[0]],
outputs=[k_transpose_output_name],
name=k_transpose_node_name,
)
k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
# Transpose the query output of the rotary embedding
q_transpose_node_name = self.model.create_node_name("Transpose")
q_transpose_output_name = q_transpose_node_name + "_output_0"
q_transpose_node = helper.make_node(
"Transpose",
inputs=[concat_q_half.output[0]],
outputs=[q_transpose_output_name],
name=q_transpose_node_name,
)
q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])
hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
if hidden_size_concat_node is None:
logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
return
# Reshape the Rotary Embedding output for key from 4D to 3D
concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
concat_k_reshape_node = helper.make_node(
"Reshape",
inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
outputs=[concat_k_reshape_node_name + "_output_0"],
name=concat_k_reshape_node_name,
)
# Reshape the Rotary Embedding output for query from 4D to 3D
concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
concat_q_reshape_node = helper.make_node(
"Reshape",
inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
outputs=[concat_q_reshape_node_name + "_output_0"],
name=concat_q_reshape_node_name,
)
rotary_k = concat_k_reshape_node
rotary_q = concat_q_reshape_node
self.nodes_to_add.append(hidden_size_concat_node)
self.nodes_to_add.append(k_transpose_node)
self.nodes_to_add.append(q_transpose_node)
self.nodes_to_add.append(concat_k_reshape_node)
self.nodes_to_add.append(concat_q_reshape_node)
self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name
new_node = self.create_mha_node(
matmul_q.input[0],
root_output,
rotary_q,
rotary_k,
matmul_v,
attn_mask,
add_qk_str,
past_k,
past_v,
present_k,
present_v,
)
if new_node is None:
logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend(qkv_nodes[1:])
if v_nodes != v_nodes_4:
self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
else:
nodes_to_keep = [v_nodes[0][-1]]
for temp_path in v_nodes:
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
self.nodes_to_remove.extend(qk_nodes)
if k_nodes == k_nodes_1:
self.nodes_to_remove.extend(k_nodes[:-2])
elif k_nodes == k_nodes_2:
self.nodes_to_remove.append(k_nodes[0])
self.nodes_to_remove.append(k_nodes[2])
self.nodes_to_remove.append(k_nodes[3])
elif k_nodes == k_nodes_3:
self.nodes_to_remove.append(k_nodes[0])
self.nodes_to_remove.append(k_nodes[1])
self.nodes_to_remove.append(k_nodes[3])
self.nodes_to_remove.append(k_nodes[4])
elif k_nodes == k_nodes_5:
self.nodes_to_remove.append(k_nodes[0])
self.nodes_to_remove.append(k_nodes[1])
elif k_nodes == k_nodes_4:
nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
for temp_path in k_nodes:
self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)
if q_nodes == q_nodes_1:
self.nodes_to_remove.extend(q_nodes[:-2])
elif q_nodes == q_nodes_2:
self.nodes_to_remove.append(q_nodes[1])
self.nodes_to_remove.append(q_nodes[2])
self.prune_graph = True
class FusionRotaryEmbeddings(Fusion):
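"""
Fuse rotary positional embedding subgraphs into a single com.microsoft RotaryEmbedding node.
Handles both a RotaryEmbedding nn.Module exported as an ONNX function and the inlined
rotate_half/apply_rope node pattern rooted at an Add node.
"""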
def __init__(self, model: OnnxModel):
self.base_name = "RotaryEmbedding"
super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"])
# The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output.
# This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter.
# To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used.
def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
# Find extra outputs and Constant nodes attached to those outputs
extra_constants, extra_outputs = [], []
for fn_node in function.node:
if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output:
extra_constants.append(fn_node)
output_index = list(function.output).index(fn_node.output[0])
extra_outputs.append(rot_emb_node.output[output_index])
# Set extra Constant node outputs as initializers
extra_initializers = []
for extra_constant in extra_constants:
constant_tensorproto = extra_constant.attribute[0].t
constant_tensorproto.name = self.model.create_node_name("Constant")
self.model.add_initializer(constant_tensorproto)
extra_initializers.append(constant_tensorproto.name)
# Update references of Constant node outputs to initializer references
for extra_output, extra_initializer in zip(extra_outputs, extra_initializers):
nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node))
for node_to_update in nodes_to_update:
OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer)
return extra_outputs
def create_rotary_embeddings_from_function(self, node: NodeProto):
rotary_emb_node_name = self.model.create_node_name(self.base_name)
matmul_path = self.model.match_parent_path(
node,
["Reshape", "MatMul"],
[0, 0],
)
if matmul_path is not None:
reshape_node, matmul_node = matmul_path
else:
logger.debug("fuse_rotary_embeddings: failed to match MatMul")
return
rotary_emb_inputs = [
matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H)
node.input[1], # position_ids
]
# Convert cos_cache and sin_cache from node attributes to model initializers
cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node))
sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node))
cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
if (
len(cos_cache_node) == 1
and len(sin_cache_node) == 1
and self.model.get_initializer(cos_cache_name) is None
and self.model.get_initializer(sin_cache_name) is None
):
cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
cos_cache_tensor = helper.make_tensor(
name=cos_cache_name,
data_type=TensorProto.FLOAT,
dims=list(cos_cache.shape),
vals=cos_cache.flatten().tolist(),
)
self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
sin_cache_tensor = helper.make_tensor(
name=sin_cache_name,
data_type=TensorProto.FLOAT,
dims=list(sin_cache.shape),
vals=sin_cache.flatten().tolist(),
)
self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])
rotary_emb_outputs = node.output
if len(rotary_emb_outputs) > 1:
# Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions))
assert len(func) == 1
extra_outputs = self.reassign_extra_outputs(node, func[0])
rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs))
assert len(rotary_emb_outputs) == 1
rotary_emb_node = helper.make_node(
self.base_name,
inputs=rotary_emb_inputs,
outputs=rotary_emb_outputs,
name=rotary_emb_node_name,
interleaved=1,
)
rotary_emb_node.domain = "com.microsoft"
self.nodes_to_remove.append(reshape_node)
return rotary_emb_node
def create_rotary_embeddings_from_nodes(
self,
root_input: str,
position_ids: str,
cos_slice: str,
sin_slice: str,
output: str,
):
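"""Create a com.microsoft RotaryEmbedding node from the matched rotate_half/apply_rope subgraph.
cos_slice/sin_slice are the output names of the Constant nodes holding the cos/sin caches; the caches are
converted to initializers and trimmed from (M, H) to (M, H/2) because the fused op uses half-size caches
with interleaved=0.
"""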
rotary_emb_node_name = self.model.create_node_name(self.base_name)
# Convert cos_cache and sin_cache from node attributes to model initializers
cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node))
sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node))
cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"
if (
len(cos_cache_node) == 1
and len(sin_cache_node) == 1
and self.model.get_initializer(cos_cache_name) is None
and self.model.get_initializer(sin_cache_name) is None
):
cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()
# Reshape cos/sin cache from (M, H) to (M, H/2)
head_size = cos_cache.shape[1]
cos_cache = cos_cache[:, : (head_size // 2)]
sin_cache = sin_cache[:, : (head_size // 2)]
cos_cache_tensor = helper.make_tensor(
name=cos_cache_name,
data_type=TensorProto.FLOAT,
dims=list(cos_cache.shape),
vals=cos_cache.flatten().tolist(),
)
self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
sin_cache_tensor = helper.make_tensor(
name=sin_cache_name,
data_type=TensorProto.FLOAT,
dims=list(sin_cache.shape),
vals=sin_cache.flatten().tolist(),
)
self.model.add_initializer(sin_cache_tensor, self.this_graph_name)
self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])
rotary_emb_node = helper.make_node(
self.base_name,
inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
outputs=[output],
name=rotary_emb_node_name,
interleaved=0,
)
rotary_emb_node.domain = "com.microsoft"
return rotary_emb_node
def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Node is either RotaryEmbedding function or Add
if self.base_name not in node.op_type and node.op_type != "Add":
return
# Check if node is "RotaryEmbedding nn.Module" exported as a function
# (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
rotary_emb_node = None
if node.op_type != "Add":
# Verify that function has the correct inputs
if len(node.input) not in {4, 5} or node.input[1] not in {
"pos",
"pos_id",
"position_id",
"pos_ids",
"position_ids",
}:
logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
return
rotary_emb_node = self.create_rotary_embeddings_from_function(node)
if rotary_emb_node is None:
logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
return
# Remove RotaryEmbedding function
self.nodes_to_remove.append(node)
# Remove RotaryEmbedding function's shape inference stored in value_info
# The new shape will be calculated during symbolic shape inference
old_shape_infer = list(
filter(lambda value_info: value_info.name == rotary_emb_node.output[0], self.model.model.graph.value_info)
)
assert len(old_shape_infer) == 1
self.model.model.graph.value_info.remove(old_shape_infer[0])
else:
# Rotary embeddings are defined using the below functions:
#
# def rotate_half(x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
#
# def apply_rope(x, cos, sin, position_ids):
# cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
# sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
# cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
# x_embed = (x * cos) + (rotate_half(x) * sin)
# return x_embed
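# The fused com.microsoft RotaryEmbedding op reproduces this computation with interleaved=0 and cos/sin
# caches trimmed to head_size/2 (see create_rotary_embeddings_from_nodes)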
# Check paths for rotate_half(x)
rotate_half_x2_path_1_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Transpose"],
[1, 0, 0, 0, 0],
)
rotate_half_x2_path_1_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Slice"],
[1, 0, 0, 0, 0],
)
rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2
rotate_half_x2_path_2_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
[1, 0, 0, 0, 1, 0, 0, 0, 0],
)
rotate_half_x2_path_2_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
[1, 0, 0, 0, 1, 0, 0, 0, 0],
)
rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2
if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
return
rotate_half_x1_path_1_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Transpose"],
[1, 0, 1, 0],
)
rotate_half_x1_path_1_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Slice"],
[1, 0, 1, 0],
)
rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2
rotate_half_x1_path_2_1 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
[1, 0, 1, 2, 0, 0, 0, 0],
)
rotate_half_x1_path_2_2 = self.model.match_parent_path(
node,
["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
[1, 0, 1, 2, 0, 0, 0, 0],
)
rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2
if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
return
if (
rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
):
logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
return
# Check path for x
x_path_1 = self.model.match_parent_path(
node,
["Mul", "Transpose"],
[0, 0],
)
x_path_2 = self.model.match_parent_path(
node,
["Mul", "Slice"],
[0, 0],
)
x_path = x_path_1 or x_path_2
if x_path is None:
logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
return
# Check path for sin
sin_path, sin_cache, position_ids = None, "", ""
sin_path_1 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
[1, 1, 0, 0, 0, 0, 2, 0, 0],
)
sin_path_2 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
[1, 1, 0, 0, 0, 0, 2, 0],
)
sin_path_3 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
[1, 1, 0, 0, 2, 0, 0],
)
sin_path_4 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
[1, 1, 0, 0, 2, 0],
)
if sin_path_1 is not None:
sin_path = sin_path_1
sin_cache = sin_path[-4].input[0]
elif sin_path_2 is not None:
sin_path = sin_path_2
sin_cache = sin_path[-3].input[0]
elif sin_path_3 is not None:
sin_path = sin_path_3
sin_cache = sin_path[-4].input[0]
position_ids = sin_path[2].input[1]
elif sin_path_4 is not None:
sin_path = sin_path_4
sin_cache = sin_path[-3].input[0]
position_ids = sin_path[2].input[1]
else:
logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
return
# Check path for cos
cos_path, cos_cache = None, ""
cos_path_1 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
[0, 1, 0, 0, 0, 0, 2, 0, 0],
)
cos_path_2 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
[0, 1, 0, 0, 0, 0, 2, 0],
)
cos_path_3 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
[0, 1, 0, 0, 2, 0, 0],
)
cos_path_4 = self.model.match_parent_path(
node,
["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
[0, 1, 0, 0, 2, 0],
)
if cos_path_1 is not None:
cos_path = cos_path_1
cos_cache = cos_path[-4].input[0]
elif cos_path_2 is not None:
cos_path = cos_path_2
cos_cache = cos_path[-3].input[0]
elif cos_path_3 is not None:
cos_path = cos_path_3
cos_cache = cos_path[-4].input[0]
position_ids = cos_path[2].input[1]
elif cos_path_4 is not None:
cos_path = cos_path_4
cos_cache = cos_path[-3].input[0]
position_ids = cos_path[2].input[1]
else:
logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
return
# Check path for position ids
if position_ids == "":
position_ids_from_sin_path = self.model.match_parent_path(
sin_path[2],
["Reshape"],
[1],
)
position_ids_from_cos_path = self.model.match_parent_path(
cos_path[2],
["Reshape"],
[1],
)
if (
position_ids_from_sin_path is None
or position_ids_from_cos_path is None
or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
):
logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
return
position_ids = position_ids_from_cos_path[0].input[0]
else:
position_ids_from_sin_path = []
position_ids_from_cos_path = []
past_seq_len_path, curr_seq_len_path = None, None
if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
sin_path == sin_path_3 and cos_path == cos_path_3
):
if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
logger.debug(
"fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
)
return
elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
sin_path == sin_path_4 and cos_path == cos_path_4
):
if sin_path[-1].name != cos_path[-1].name:
logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
return
# Match past sequence length path: past_key --> Shape --> Gather --> Add
past_seq_len_path = self.model.match_parent_path(
sin_path[-1],
["Gather", "Shape"],
[1, 0],
)
# Match current sequence length path: transpose_k --> Shape --> Gather --> Add
curr_seq_len_path = self.model.match_parent_path(
sin_path[-1],
["Gather", "Shape", "Transpose"],
[0, 0, 0],
)
if (
past_seq_len_path is None
or curr_seq_len_path is None
or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
or curr_seq_len_path[-1].op_type != "Transpose"
):
logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
return
else:
logger.debug("fuse_rotary_embeddings: failed to match common cache paths")
rotary_emb_node = self.create_rotary_embeddings_from_nodes(
rotate_half_x1_path_1[-1].output[0],
position_ids,
cos_cache,
sin_cache,
node.output[0],
)
if rotary_emb_node is None:
logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
return
# Remove rotary embedding nodes
self.add_nodes_to_remove([node])
self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
self.add_nodes_to_remove(x_path[:-1])
self.add_nodes_to_remove(sin_path)
self.add_nodes_to_remove(cos_path)
self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
self.add_nodes_to_remove(position_ids_from_cos_path[:-1])
if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
# In the merged HF model, the output of the Gather in past_seq_len_path is used more than once (once for
# past_key_values.0.key and again for the other past_key_values), so only remove the path when the Gather
# has a single consumer
self.add_nodes_to_remove(past_seq_len_path)
if curr_seq_len_path is not None:
self.add_nodes_to_remove(curr_seq_len_path[:-1])
self.increase_counter(self.base_name)
self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
self.nodes_to_add.append(rotary_emb_node)
self.prune_graph = True