# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
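"""Graph transformations for a TorchDynamo-exported Phi-2 ONNX model.

The "fission" passes below unroll the exported module functions (token
embedding, decoder layers, final layer norm, LM head) into plain ONNX and
ORT contrib ops such as Attention, MultiHeadAttention, GroupQueryAttention,
and PagedAttention.
"""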
from logging import getLogger
from typing import List, Optional
import numpy as np
from dynamo_onnx_helper import DynamoOnnxHelper
from fusion_base import Fusion
from fusion_options import AttentionOpType, FusionOptions
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import NumpyHelper
from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
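# Functors passed to Fission.process_initializer to reshape exported initializers:
# transpose GEMM weights, split stacked QKV weights/biases into Q/K/V, and trim
# the rotary cos/sin caches.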
class ProcessGemmWFunc:
def __call__(self, x):
return np.transpose(x, (1, 0))
class ProcessMatMulQFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[0], (1, 0))
class ProcessMatMulKFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[1], (1, 0))
class ProcessMatMulVFunc:
def __call__(self, x):
return np.transpose(np.split(x, 3, 0)[2], (1, 0))
class ProcessBiasQFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[0]
return x
class ProcessBiasKFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[1]
return x
class ProcessBiasVFunc:
def __call__(self, x):
x = np.split(x, 3, -1)[2]
return x
class ProcessRotCacheFunc:
def __call__(self, x):
        # Phi-2 uses partial rotary embeddings with rotary_dim = 32. The RotaryEmbedding
        # contrib op expects a cos/sin cache of width rotary_dim / 2, so keep only the
        # first 16 columns when the exported cache still carries all 32.
assert len(x.shape) == 2
if x.shape[1] == 32:
return x[:, 0:16]
return x
# TODO: move to a separate file
class Fission(Fusion):
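    """Base class for passes that replace a Dynamo-exported function node with an
    equivalent subgraph of plain ONNX / ORT contrib ops ("fission" rather than fusion;
    the "DONOTUSE" fused-op name passed to Fusion is only a placeholder)."""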
def __init__(
self,
model: OnnxModel,
nodes_to_find: List[str],
):
super().__init__(model, "DONOTUSE", nodes_to_find)
def set_attention_op_type(self, attn_op_type: AttentionOpType):
self.attn_op_type = attn_op_type
def get_uname(self, layer_id, name):
return name + "_" + str(layer_id)
def get_edge_by_name(self, edges, name):
for edge in edges:
if edge == name or edge.endswith(name) or edge.startswith(name):
return edge
raise ValueError(f"Edge {name} not found")
def get_input_by_name(self, node, name):
return self.get_edge_by_name(node.input, name)
def get_output_by_name(self, node, name):
return self.get_edge_by_name(node.output, name)
def process_initializer(self, initializer_name, functor, custom_name=None):
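        """Apply `functor` to the named initializer and add the result as a new fp32 initializer; returns the new name."""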
i = self.model.get_initializer(initializer_name)
i_np_array = NumpyHelper.to_array(i)
processed_i_np_array = functor(i_np_array)
new_tensor = helper.make_tensor(
initializer_name + "_processed" if custom_name is None else custom_name,
data_type=TensorProto.FLOAT,
dims=processed_i_np_array.shape,
vals=processed_i_np_array.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(new_tensor, self.this_graph_name)
return new_tensor.name
def add_fp32_value_info(self, name):
new_value_info = self.model.graph().value_info.add()
new_value_info.name = name
new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
def add_int64_value_info(self, name):
new_value_info = self.model.graph().value_info.add()
new_value_info.name = name
new_value_info.type.tensor_type.elem_type = TensorProto.INT64
def replace_fp32_value_info(self, name, shape):
for value_info in self.model.graph().value_info:
if value_info.name == name:
self.model.graph().value_info.remove(value_info)
break
new_value_info = helper.make_tensor_value_info(
name,
elem_type=TensorProto.FLOAT,
shape=shape,
)
self.model.graph().value_info.extend([new_value_info])
def set_unique_name_and_add_nodes(
self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
):
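        """Suffix internal edge and node names with the layer id, register fp32 value_info for the new edges, and queue the nodes for insertion."""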
for new_node in subgraph_nodes:
for i, name in enumerate(new_node.input):
if name == "":
continue
elif name not in layer_known_edges_names:
new_node.input[i] = self.get_uname(layer_id, name)
self.add_fp32_value_info(new_node.input[i])
for i, name in enumerate(new_node.output):
if name == "":
continue
elif name not in layer_known_edges_names:
new_node.output[i] = self.get_uname(layer_id, name)
self.add_fp32_value_info(new_node.output[i])
new_node.name = self.get_uname(layer_id, new_node.name)
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
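    # Subgraph builders: each helper returns a list of NodeProto making up one piece of the expanded layer.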
def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 3
assert len(outputs) == 1
node = helper.make_node(
"LayerNormalization",
inputs=inputs,
outputs=outputs,
name=prefix + "_LayerNormalization",
            epsilon=9.999999747378752e-06,  # 1e-5 rounded to fp32
)
return [node]
def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 3
assert len(outputs) == 1
matmul = helper.make_node(
"MatMul",
inputs=[inputs[0], inputs[1]],
outputs=[prefix + "matmul_out"],
name=prefix + "MatMul",
)
add = helper.make_node(
"Add",
inputs=[prefix + "matmul_out", inputs[2]],
outputs=outputs,
name=prefix + "Bias",
)
return [matmul, add]
def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
assert len(inputs) == 4
assert len(outputs) == 1
node = helper.make_node(
"RotaryEmbedding",
inputs=inputs,
outputs=outputs,
name=prefix + "RotaryEmbedding",
domain="com.microsoft",
rotary_embedding_dim=rot_dim,
num_heads=num_heads,
)
return [node]
def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 1
assert len(outputs) == 1
node = helper.make_node(
"FastGelu",
inputs=inputs,
outputs=outputs,
name=prefix + "FastGelu",
domain="com.microsoft",
)
return [node]
def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
assert len(inputs) == 2
assert len(outputs) == 1
node = helper.make_node(
"Add",
inputs=inputs,
outputs=outputs,
name=prefix + "Add",
)
return [node]
def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 8
assert len(outputs) == 3
node = helper.make_node(
"MultiHeadAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "MultiHeadAttention",
domain="com.microsoft",
num_heads=num_heads,
unidirectional=1,
)
return [node]
def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 7
assert len(outputs) == 3
node = helper.make_node(
"GroupQueryAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "GroupQueryAttention",
domain="com.microsoft",
num_heads=num_heads,
kv_num_heads=num_heads,
)
return [node]
def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
assert len(inputs) == 5
assert len(outputs) == 2
node = helper.make_node(
"Attention",
inputs=inputs,
outputs=outputs,
name=prefix + "Attention",
domain="com.microsoft",
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
rotary_embedding_dim=32,
)
return [node]
def paged_attn(
self,
inputs: List[str],
outputs: List[str],
prefix: str = "",
num_heads=32,
head_size=80,
        scale=0.11180339753627777,  # 1 / sqrt(head_size) = 1 / sqrt(80), rounded to fp32
):
assert len(inputs) == 6
assert len(outputs) == 1
node = helper.make_node(
"PagedAttention",
inputs=inputs,
outputs=outputs,
name=prefix + "PagedAttention",
domain="vllm.ort.ext",
num_heads=num_heads,
num_kv_heads=num_heads,
head_size=head_size,
scale=scale,
)
return [node]
class Phi2PreProcessor(DynamoOnnxHelper):
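    """Prepares the raw Dynamo export of Phi-2: unrolls the PhiModel function, renames
    graph edges to canonical names, simplifies layer op types, removes dropout, and
    rewrites graph inputs/outputs for the selected attention op."""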
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
super().__init__(model)
self.num_hidden_layers = 32
self.num_attention_heads = num_heads
self.hidden_size = hidden_size
self.func_name = "modeling_phi_PhiModel_model_1"
def get_phi2_edge_dict(self) -> dict:
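        """Map exporter-generated edge names to canonical I/O names (input_ids, logits,
        past_key_i/past_value_i, present_key_i/present_value_i). Layer 0 uses different
        exporter names than the remaining layers, so it is handled outside the loop."""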
edge_dict = {}
edge_dict["lm_head_1"] = "logits"
edge_dict["l_input_ids_"] = "input_ids"
edge_dict["key_states"] = "past_key_0"
edge_dict["value_states"] = "past_value_0"
        for i in range(1, self.num_hidden_layers):
edge_dict[f"key_states_{i}"] = f"past_key_{i}"
edge_dict[f"value_states_{i}"] = f"past_value_{i}"
edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
outputs = [o.name for o in self.model.graph.output]
if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
edge_dict["model_layers_0_1_1"] = "present_key_0"
edge_dict["model_layers_0_1_2"] = "present_value_0"
else:
assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
edge_dict["model_layers_0_1"] = "present_key_0"
edge_dict["model_layers_0_1_1"] = "present_value_0"
return edge_dict
def simplify_phi2_op_type(self):
phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
for node in self.model.graph.node:
index = node.op_type.find(phi2_transformer_layer_name)
if index != -1:
node.op_type = node.op_type[index:]
def process_graph_io(self, attn_op_type: AttentionOpType):
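        """Rebuild graph inputs and outputs for the chosen attention op: a packed 5-D KV
        cache per layer for Attention, block-layout caches plus position_ids/input_metadata
        for PagedAttention, and separate 4-D past/present tensors otherwise."""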
self.use_attn = attn_op_type == AttentionOpType.Attention
self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
graph = self.model.graph
new_inputs = []
for vi in graph.input:
if "input_ids" in vi.name:
vi_iid = helper.make_tensor_value_info(
vi.name,
elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
shape=["batch_size", "seq_len"],
)
vi_step = helper.make_tensor_value_info(
"step",
elem_type=TensorProto.INT64,
shape=[1],
)
vi_pid = helper.make_tensor_value_info(
"position_ids",
elem_type=TensorProto.INT64,
shape=["batch_size", "seq_len"],
)
vi_mask = helper.make_tensor_value_info(
"attention_mask",
elem_type=TensorProto.INT32,
shape=["batch_size", "seq_len"],
)
vi_meta = helper.make_tensor_value_info(
"input_metadata",
elem_type=TensorProto.INT64,
shape=[1],
)
                if not self.use_vllm:
                    new_inputs.extend([vi_iid, vi_step, vi_mask])
                else:
                    new_inputs.extend([vi_iid, vi_pid, vi_meta])
if self.use_attn:
if "past_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("past_key", "past"),
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"past_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_inputs.extend([vi_cache])
elif self.use_vllm:
if "past_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
)
new_inputs.extend([vi_cache])
if "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"num_blocks",
"num_heads",
"head_size",
"block_size",
],
)
new_inputs.extend([vi_cache])
else:
if "past_key" in vi.name or "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"past_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_inputs.extend([vi_cache])
graph.ClearField("input")
graph.input.extend(new_inputs)
new_outputs = []
for i, vi in enumerate(graph.output):
if i == 0:
new_outputs.extend([vi])
else:
if self.use_attn:
if "present_key" in vi.name:
vi_cache = helper.make_tensor_value_info(
vi.name.replace("present_key", "present"),
elem_type=vi.type.tensor_type.elem_type,
shape=[
2,
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
elif self.use_vllm:
pass
else:
vi_cache = helper.make_tensor_value_info(
vi.name,
elem_type=vi.type.tensor_type.elem_type,
shape=[
"batch_size",
self.num_attention_heads,
"total_seq_len",
self.hidden_size // self.num_attention_heads,
],
)
new_outputs.extend([vi_cache])
graph.ClearField("output")
graph.output.extend(new_outputs)
def preprocess_onnx(self, attn_op_type: AttentionOpType):
function_name = None
for func in self.model.functions:
if func.name.endswith(self.func_name):
function_name = func.name
break
assert function_name is not None
self.unroll_function(function_name)
self.update_edges(self.get_phi2_edge_dict())
self.simplify_phi2_op_type()
self.remove_dropout_layer()
if attn_op_type == AttentionOpType.PagedAttention:
self.remove_lm_head_layer()
self.process_graph_io(attn_op_type)
class FissionTransformerEmbeddingPhi(Fission):
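    """Replaces the exported token-embedding function with a single Gather node."""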
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 2
assert len(node.output) == 1
input = node.input[0]
output = node.output[0]
embedding = self.get_input_by_name(node, "embed_tokens.weight")
layer_known_edges_names = [input, output, embedding]
subgraph_nodes = [
helper.make_node(
"Gather",
inputs=[embedding, input],
outputs=[output],
name="Embedding_Gather",
),
]
self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerLayerNormPhi(Fission):
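    """Replaces the exported final LayerNorm function with a LayerNormalization node."""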
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 3
assert len(node.output) == 1
input = node.input[0]
output = node.output[0]
ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
ln_bias = self.get_input_by_name(node, "final_layernorm.bias")
layer_known_edges_names = [input, output, ln_weight, ln_bias]
subgraph_nodes = []
subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerCausalLMHeadPhi(Fission):
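    """Replaces the exported LM head function with a MatMul + Add pair producing the logits."""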
def __init__(
self,
model: OnnxModel,
):
super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
logger.info("Optimizing %s...", node.name)
assert len(node.input) == 5
assert len(node.output) == 1
input = node.input[2]
output = node.output[0]
fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
fc_bias = self.get_input_by_name(node, "lm_head.bias")
layer_known_edges_names = [input, output, fc_weight, fc_bias]
subgraph_nodes = []
subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
self.nodes_to_remove.append(node)
self.prune_graph = True
class FissionTransformerBlockPhi(Fission):
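    """Replaces each exported Phi decoder-layer function with an explicit subgraph:
    LayerNormalization, Q/K/V (or packed QKV) projections, rotary embedding, the selected
    attention op, the MLP (FC1 -> FastGelu -> FC2), and the residual additions."""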
def __init__(
self,
model: OnnxModel,
num_heads: int,
):
self.num_heads = num_heads
max_num_layers = 32
self.func_to_layer_id = {}
nodes_to_find = []
for layer in range(max_num_layers):
func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
nodes_to_find.append(func_name)
self.func_to_layer_id[func_name] = layer
super().__init__(model, nodes_to_find)
def get_layer_id(self, node):
return self.func_to_layer_id[node.op_type]
def get_gqa_aux_nodes(self):
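        """Shared helper nodes that derive the seqlens_k and total_sequence_length inputs
        required by GroupQueryAttention from the attention_mask."""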
gqa_aux_nodes = [
helper.make_node(
"Cast",
inputs=["attention_mask"],
outputs=["mask_int64"],
name="Cast_gqa_aux_0",
to=TensorProto.INT64,
),
helper.make_node(
"ReduceSum",
inputs=["mask_int64", "one"],
outputs=["mask_row_sums"],
name="ReduceSum_gqa_aux",
),
helper.make_node(
"Sub",
inputs=["mask_row_sums", "one"],
outputs=["seqlens_k_int64"],
name="Sub_gqa_aux",
),
helper.make_node(
"Cast",
inputs=["seqlens_k_int64"],
outputs=["seqlens_k"],
name="Cast_gqa_aux_1",
to=TensorProto.INT32,
),
helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
helper.make_node(
"Gather",
inputs=["mask_shape", "one"],
outputs=["total_seq_len_int64"],
name="Gather_gqa_aux_0",
axis=0,
),
helper.make_node(
"Cast",
inputs=["total_seq_len_int64"],
outputs=["total_sequence_length"],
name="Cast_gqa_aux_2",
to=TensorProto.INT32,
),
]
return gqa_aux_nodes
def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
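        """Pack separate Q/K/V projection weights and biases into single QKV initializers
        (Q, K, V concatenated along the output dimension) for the Attention contrib op;
        returns their names."""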
q_weight = self.model.get_initializer(q_w)
k_weight = self.model.get_initializer(k_w)
v_weight = self.model.get_initializer(v_w)
qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
qkv_weight = np.stack((qw, kw, vw), axis=1)
q_bias = self.model.get_initializer(q_b)
k_bias = self.model.get_initializer(k_b)
v_bias = self.model.get_initializer(v_b)
qb = NumpyHelper.to_array(q_bias)
kb = NumpyHelper.to_array(k_bias)
vb = NumpyHelper.to_array(v_bias)
qkv_bias = np.stack((qb, kb, vb), axis=0)
hidden_size = qkv_weight.shape[0]
weight = helper.make_tensor(
weight_name,
data_type=TensorProto.FLOAT,
dims=[hidden_size, hidden_size * 3],
vals=qkv_weight.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(weight, self.this_graph_name)
bias = helper.make_tensor(
bias_name,
data_type=TensorProto.FLOAT,
dims=[hidden_size * 3],
vals=qkv_bias.flatten().tobytes(),
raw=True,
)
self.model.add_initializer(bias, self.this_graph_name)
self.add_fp32_value_info(weight.name)
self.add_fp32_value_info(bias.name)
return weight_name, bias_name
def fuse(
self,
node,
input_name_to_nodes,
output_name_to_node,
):
logger.info("Optimizing %s...", node.name)
logger.info(f"AttentionOpType: {self.attn_op_type}")
layer_id = self.get_layer_id(node)
i_hidden_states = node.input[0]
i_key_cache = self.get_input_by_name(node, "past_key")
i_value_cache = self.get_input_by_name(node, "past_value")
o_hidden_states = node.output[-1]
o_key_cache = self.get_output_by_name(node, "present_key")
o_value_cache = self.get_output_by_name(node, "present_value")
ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
ln_bias = self.get_input_by_name(node, "input_layernorm.bias")
attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
None,
None,
None,
None,
None,
None,
)
attn_qkv_weight, attn_qkv_bias = None, None
cos_cache, sin_cache = None, None
if self.attn_op_type != AttentionOpType.Attention:
attn_q_weight = self.process_initializer(
self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
)
attn_k_weight = self.process_initializer(
self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
)
attn_v_weight = self.process_initializer(
self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
)
attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")
cos_cache = self.process_initializer(
self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
)
sin_cache = self.process_initializer(
self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
)
else:
attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
self.get_input_by_name(node, "self_attn.q_proj.weight"),
self.get_input_by_name(node, "self_attn.k_proj.weight"),
self.get_input_by_name(node, "self_attn.v_proj.weight"),
self.get_input_by_name(node, "self_attn.q_proj.bias"),
self.get_input_by_name(node, "self_attn.k_proj.bias"),
self.get_input_by_name(node, "self_attn.v_proj.bias"),
self.get_uname(layer_id, "attn_qkv_weight"),
self.get_uname(layer_id, "attn_qkv_bias"),
)
attn_out_weight = self.process_initializer(
self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
)
attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")
mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")
layer_known_edges_names = []
layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
layer_known_edges_names.extend([ln_weight, ln_bias])
if self.attn_op_type != AttentionOpType.Attention:
layer_known_edges_names.extend(
[
attn_q_weight,
attn_q_bias,
attn_k_weight,
attn_k_bias,
attn_v_weight,
attn_v_bias,
cos_cache,
sin_cache,
]
)
else:
layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
layer_known_edges_names.extend(
[attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
)
layer_known_edges_names.extend(
["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
)
subgraph_nodes = []
subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
if self.attn_op_type != AttentionOpType.Attention:
subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
# vllm engine requires full position ids as the input
pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
if self.attn_op_type == AttentionOpType.MultiHeadAttention:
subgraph_nodes.extend(
self.mha(
["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
["attn_out", o_key_cache, o_value_cache],
)
)
elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
subgraph_nodes.extend(
self.gqa(
[
"query_rot",
"key_rot",
"value",
i_key_cache,
i_value_cache,
"seqlens_k",
"total_sequence_length",
],
["attn_out", o_key_cache, o_value_cache],
)
)
if layer_id == 0:
gqa_aux_nodes = self.get_gqa_aux_nodes()
for new_node in gqa_aux_nodes:
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.model.add_initializer(
numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
)
elif self.attn_op_type == AttentionOpType.PagedAttention:
subgraph_nodes.extend(
self.paged_attn(
["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
["attn_out"],
)
)
else:
past_name = f"past_{layer_id}"
present_name = f"present_{layer_id}"
layer_known_edges_names.extend([past_name, present_name])
subgraph_nodes.extend(
self.attention(
["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
)
)
self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
self.nodes_to_remove.append(node)
self.prune_graph = True
class PhiOnnxModel(OnnxModel):
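    """Phi-2 ONNX model wrapper that drives preprocessing and the fission passes.

    A minimal usage sketch; the paths and hyperparameters below are illustrative,
    not taken from this module:

        import onnx
        from fusion_options import AttentionOpType, FusionOptions

        model = onnx.load("phi2_dynamo.onnx")  # hypothetical export path
        options = FusionOptions("phi")
        options.attention_op_type = AttentionOpType.GroupQueryAttention
        phi_model = PhiOnnxModel(model, num_heads=32, hidden_size=2560)
        phi_model.optimize(options)
        phi_model.save_model_to_file("phi2_optimized.onnx")  # hypothetical output path
    """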
def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
super().__init__(model)
self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
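        """Run preprocessing and the fission passes, prune the graph, then fuse (Bias)SkipLayerNormalization."""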
assert options is not None
attn_op_type = options.attention_op_type
self.fission_transformer_block.set_attention_op_type(attn_op_type)
self.phi2_preprocessor.preprocess_onnx(attn_op_type)
self.fission_transformer_block.apply()
self.fission_transformer_layernorm.apply()
self.fission_causal_lm_head.apply()
self.fission_transformer_embedding.apply()
super().prune_graph()
# SLN ctor is placed here intentionally to delay the symbolic shape inference
self.fuse_sln = FusionSkipLayerNormalization(self)
self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
self.fuse_sln.apply()
self.fuse_bias_sln.apply()
def get_fused_operator_statistics(self):
"""
Returns node count of fused operators.
"""
op_count = {}
ops = [
"Attention",
"MultiHeadAttention",
"GroupQueryAttention",
"PagedAttention",
"Gelu",
"BiasGelu",
"FastGelu",
"LayerNormalization",
"SkipLayerNormalization",
]
for op in ops:
nodes = self.get_nodes_by_op_type(op)
op_count[op] = len(nodes)
logger.info(f"Optimized operators: {op_count}")
return op_count
def is_fully_optimized(self, fused_op_count=None):
"""
Returns True when the model is fully optimized.
"""
if fused_op_count is None:
fused_op_count = self.get_fused_operator_statistics()
def op_count(op_name: str):
return fused_op_count.get(op_name) or 0
attention = (
op_count("Attention")
+ op_count("MultiHeadAttention")
+ op_count("GroupQueryAttention")
+ op_count("PagedAttention")
)
gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)
if layer_norm == 0:
logger.debug("Layer Normalization not fused")
if gelu == 0:
logger.debug("Gelu (or FastGelu) not fused")
if attention == 0:
logger.warning("Attention (or MultiHeadAttention) not fused")
return is_perfect