# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union
import numpy as np
from fusion_base import Fusion
from onnx import NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionAttentionVae(Fusion):
"""
    Fuse the attention subgraph of the VAE decoder into one Attention node.
"""
def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int):
super().__init__(model, "Attention", ["Softmax"])
self.hidden_size = hidden_size
self.num_heads = num_heads
# Flags to show warning only once
self.num_heads_warning = True
self.hidden_size_warning = True
def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]:
"""Detect num_heads and hidden_size from a reshape node.
Args:
reshape_q (NodeProto): reshape node for Q
add_q (NodeProto): add node for Q
Returns:
Tuple[int, int]: num_heads and hidden_size
"""
concat = self.model.get_parent(reshape_q, 1)
if concat is None or len(concat.input) != 4:
return self.num_heads, self.hidden_size # Fall back to user specified value
value = self.model.get_constant_value(concat.input[2])
        if value is None or not isinstance(value, np.ndarray) or value.size != 1:
return self.num_heads, self.hidden_size # Fall back to user specified value
        num_heads = int(value.item())
if num_heads <= 0:
return self.num_heads, self.hidden_size # Fall back to user specified value
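        # hidden_size is taken from the length of the Q projection bias (a 1-D tensor).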
_, bias = self.model.get_constant_input(add_q)
if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1:
return self.num_heads, self.hidden_size # Fall back to user specified value
hidden_size = bias.shape[0]
if self.num_heads > 0 and num_heads != self.num_heads:
if self.num_heads_warning:
logger.warning(
"Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads
)
self.num_heads_warning = False # Do not show the warning more than once
if self.hidden_size > 0 and hidden_size != self.hidden_size:
if self.hidden_size_warning:
logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size)
self.hidden_size_warning = False # Do not show the warning more than once
return num_heads, hidden_size
def create_attention_node(
self,
q_matmul: NodeProto,
q_add: NodeProto,
k_matmul: NodeProto,
k_add: NodeProto,
v_matmul: NodeProto,
v_add: NodeProto,
num_heads: int,
hidden_size: int,
input_name: str,
output_name: str,
) -> Union[NodeProto, None]:
"""Create an Attention node.
Args:
            q_matmul (NodeProto): MatMul node of the fully-connected layer for Q
            q_add (NodeProto): Add (bias) node of the fully-connected layer for Q
            k_matmul (NodeProto): MatMul node of the fully-connected layer for K
            k_add (NodeProto): Add (bias) node of the fully-connected layer for K
            v_matmul (NodeProto): MatMul node of the fully-connected layer for V
            v_add (NodeProto): Add (bias) node of the fully-connected layer for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
input_name (str): input name
output_name (str): output name
Returns:
            Union[NodeProto, None]: the Attention node created, or None if fusion cannot be applied.
"""
if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name:
logger.debug(
"For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
q_matmul.input[0],
k_matmul.input[0],
v_matmul.input[0],
)
return None
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads)
return None
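        # Q/K/V projection weights must be constant initializers; otherwise fusion is skipped.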
q_weight_tensor = self.model.get_initializer(q_matmul.input[1])
k_weight_tensor = self.model.get_initializer(k_matmul.input[1])
v_weight_tensor = self.model.get_initializer(v_matmul.input[1])
if not (q_weight_tensor and k_weight_tensor and v_weight_tensor):
return None
        q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0])
        k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0])
        v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0])
        if not (q_bias_tensor and k_bias_tensor and v_bias_tensor):
            return None
q_bias = numpy_helper.to_array(q_bias_tensor)
k_bias = numpy_helper.to_array(k_bias_tensor)
v_bias = numpy_helper.to_array(v_bias_tensor)
q_bias_shape = np.prod(q_bias.shape)
k_bias_shape = np.prod(k_bias.shape)
v_bias_shape = np.prod(v_bias.shape)
# Sometimes weights are stored in fp16
        if q_weight_tensor.data_type == TensorProto.FLOAT16:
logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
return None
q_weight = numpy_helper.to_array(q_weight_tensor)
k_weight = numpy_helper.to_array(k_weight_tensor)
v_weight = numpy_helper.to_array(v_weight_tensor)
        # Check that Q, K and V weights have the same shape.
if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape:
return None
qw_in_size = q_weight.shape[0]
kw_in_size = k_weight.shape[0]
vw_in_size = v_weight.shape[0]
        assert qw_in_size == kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
raise ValueError(
f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
"Please provide a correct input hidden size or pass in 0"
)
        # Q, K and V weights have the same shape (verified above).
        # For 2-D weights, the shape is [in_size, out_size].
        # For 3-D weights, the shape is [in_size, a, b] where a * b = out_size.
qw_out_size = np.prod(q_weight.shape[1:])
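        # Stacking along axis 1 puts the Q, K and V weights of each input row next to each other,
        # which flattens to [in_size, 3 * out_size] (equivalent to concatenating Q, K, V along the output dimension).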
qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1)
qkv_weight_dim = 3 * int(qw_out_size)
attention_node_name = self.model.create_node_name("Attention")
assert q_bias_shape == k_bias_shape == v_bias_shape
        qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0)
        qkv_bias_dim = 3 * int(q_bias_shape)
self.add_initializer(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight,
)
self.add_initializer(
name=attention_node_name + "_qkv_bias",
data_type=TensorProto.FLOAT,
dims=[qkv_bias_dim],
vals=qkv_bias,
)
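        # Inputs of the com.microsoft Attention operator: hidden state, fused QKV weight, fused QKV bias.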
attention_inputs = [
input_name,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output_name],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
self.increase_counter("Attention (self attention)")
return attention_node
def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node):
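        """Match the attention subgraph around the given Softmax node and fuse it into a single Attention node."""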
matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False)
if matmul_qkv is None:
return
reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False)
if reshape_qkv is None:
return
transpose_qkv = self.model.find_first_child_by_type(
reshape_qkv, "Transpose", input_name_to_nodes, recursive=False
)
if transpose_qkv is None:
return
reshape_out = self.model.find_first_child_by_type(
transpose_qkv, "Reshape", input_name_to_nodes, recursive=False
)
if reshape_out is None:
return
matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False)
if matmul_out is None:
return
add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False)
if add_out is None:
return
transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False)
if transpose_out is None:
return
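        # V path: matmul_qkv <- Reshape <- Transpose <- Reshape <- Add <- MatMul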
v_nodes = self.model.match_parent_path(
matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None]
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(_, _, _, add_v, matmul_v) = v_nodes
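        # QK path: matmul_qkv <- Softmax <- Add (zero) <- Mul (scaling) <- MatMul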
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
if qk_nodes is not None:
(_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return
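        # Q path: matmul_qk <- Reshape <- Transpose <- Reshape <- Add <- MatMul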
q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None]
)
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
(_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes
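        # K path: matmul_qk <- Transpose <- Reshape <- Transpose <- Reshape <- Add <- MatMul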
k_nodes = self.model.match_parent_path(
matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None]
)
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
(_, _, _, _, add_k, matmul_k) = k_nodes
attention_last_node = reshape_out
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q)
if q_num_heads <= 0:
logger.debug("fuse_attention: failed to detect num_heads")
return
        # The number of heads is the same for all paths, so pass the value detected from the Q path.
new_node = self.create_attention_node(
matmul_q,
add_q,
matmul_k,
add_k,
matmul_v,
add_v,
q_num_heads,
q_hidden_size,
matmul_q.input[0],
attention_last_node.output[0],
)
if new_node is None:
return
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.extend([attention_last_node, transpose_qkv])
        # Use prune_graph to remove the remaining nodes, since some of them may be shared with other attention subgraphs.
self.prune_graph = True