# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger from typing import Tuple, Union import numpy as np from fusion_base import Fusion from onnx import NodeProto, TensorProto, helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) class FusionAttentionVae(Fusion): """ Fuse Attention subgraph of Vae Decoder into one Attention node. """ def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int): super().__init__(model, "Attention", ["Softmax"]) self.hidden_size = hidden_size self.num_heads = num_heads # Flags to show warning only once self.num_heads_warning = True self.hidden_size_warning = True def get_num_heads_and_hidden_size(self, reshape_q: NodeProto, add_q: NodeProto) -> Tuple[int, int]: """Detect num_heads and hidden_size from a reshape node. Args: reshape_q (NodeProto): reshape node for Q add_q (NodeProto): add node for Q Returns: Tuple[int, int]: num_heads and hidden_size """ concat = self.model.get_parent(reshape_q, 1) if concat is None or len(concat.input) != 4: return self.num_heads, self.hidden_size # Fall back to user specified value value = self.model.get_constant_value(concat.input[2]) if not (value is not None and isinstance(value, np.ndarray) and value.size == 1): return self.num_heads, self.hidden_size # Fall back to user specified value num_heads = int(value) if num_heads <= 0: return self.num_heads, self.hidden_size # Fall back to user specified value _, bias = self.model.get_constant_input(add_q) if (bias is None) or (not isinstance(bias, np.ndarray)) or bias.ndim != 1: return self.num_heads, self.hidden_size # Fall back to user specified value hidden_size = bias.shape[0] if self.num_heads > 0 and num_heads != self.num_heads: if self.num_heads_warning: logger.warning( "Detected number of attention heads is %d. Ignore --num_heads %d", num_heads, self.num_heads ) self.num_heads_warning = False # Do not show the warning more than once if self.hidden_size > 0 and hidden_size != self.hidden_size: if self.hidden_size_warning: logger.warning("Detected hidden size is %d. Ignore --hidden_size %d", hidden_size, self.hidden_size) self.hidden_size_warning = False # Do not show the warning more than once return num_heads, hidden_size def create_attention_node( self, q_matmul: NodeProto, q_add: NodeProto, k_matmul: NodeProto, k_add: NodeProto, v_matmul: NodeProto, v_add: NodeProto, num_heads: int, hidden_size: int, input_name: str, output_name: str, ) -> Union[NodeProto, None]: """Create an Attention node. Args: q_matmul (NodeProto): MatMul node in fully connection for Q q_add (NodeProto): Add bias node in fully connection for Q k_matmul (NodeProto): MatMul node in fully connection for K k_add (NodeProto): Add bias node in fully connection for K v_matmul (NodeProto): MatMul node in fully connection for V v_add (NodeProto): Add bias node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. input_name (str): input name output_name (str): output name Returns: Union[NodeProto, None]: the node created or None if failed. """ if q_matmul.input[0] != input_name or k_matmul.input[0] != input_name or v_matmul.input[0] != input_name: logger.debug( "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s", q_matmul.input[0], k_matmul.input[0], v_matmul.input[0], ) return None if hidden_size > 0 and (hidden_size % num_heads) != 0: logger.debug("input hidden size %d is not a multiple of num of heads %d", hidden_size, num_heads) return None q_weight_tensor = self.model.get_initializer(q_matmul.input[1]) k_weight_tensor = self.model.get_initializer(k_matmul.input[1]) v_weight_tensor = self.model.get_initializer(v_matmul.input[1]) if not (q_weight_tensor and k_weight_tensor and v_weight_tensor): return None q_bias_tensor = self.model.get_initializer(q_add.input[1]) or self.model.get_initializer(q_add.input[0]) k_bias_tensor = self.model.get_initializer(k_add.input[1]) or self.model.get_initializer(k_add.input[0]) v_bias_tensor = self.model.get_initializer(v_add.input[1]) or self.model.get_initializer(v_add.input[0]) q_bias = numpy_helper.to_array(q_bias_tensor) k_bias = numpy_helper.to_array(k_bias_tensor) v_bias = numpy_helper.to_array(v_bias_tensor) q_bias_shape = np.prod(q_bias.shape) k_bias_shape = np.prod(k_bias.shape) v_bias_shape = np.prod(v_bias.shape) # Sometimes weights are stored in fp16 if q_weight_tensor.data_type == 10: logger.debug("weights are in fp16. Please run fp16 conversion after optimization") return None q_weight = numpy_helper.to_array(q_weight_tensor) k_weight = numpy_helper.to_array(k_weight_tensor) v_weight = numpy_helper.to_array(v_weight_tensor) # assert q and k have same shape as expected if q_weight.shape != k_weight.shape or q_weight.shape != v_weight.shape: return None qw_in_size = q_weight.shape[0] kw_in_size = k_weight.shape[0] vw_in_size = v_weight.shape[0] assert qw_in_size == kw_in_size and kw_in_size == vw_in_size if hidden_size > 0 and hidden_size != qw_in_size: raise ValueError( f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). " "Please provide a correct input hidden size or pass in 0" ) # All the matrices can have the same shape or q, k matrics can have the same shape with v being different # For 2d weights, the shapes would be [in_size, out_size]. # For 3d weights, shape would be [in_size, a, b] where a*b = out_size qw_out_size = np.prod(q_weight.shape[1:]) qkv_weight = np.stack((q_weight, k_weight, v_weight), axis=1) qkv_weight_dim = 3 * int(qw_out_size) attention_node_name = self.model.create_node_name("Attention") assert q_bias_shape == k_bias_shape == v_bias_shape qkv_bias_dim = 0 qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0) qkv_bias_dim = 3 * q_bias_shape self.add_initializer( name=attention_node_name + "_qkv_weight", data_type=TensorProto.FLOAT, dims=[qw_in_size, qkv_weight_dim], vals=qkv_weight, ) # No bias, use zeros qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) qkv_bias_dim = 3 * hidden_size self.add_initializer( name=attention_node_name + "_qkv_bias", data_type=TensorProto.FLOAT, dims=[qkv_bias_dim], vals=qkv_bias, ) attention_inputs = [ input_name, attention_node_name + "_qkv_weight", attention_node_name + "_qkv_bias", ] attention_node = helper.make_node( "Attention", inputs=attention_inputs, outputs=[output_name], name=attention_node_name, ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) self.increase_counter("Attention (self attention)") return attention_node def fuse(self, softmax_node, input_name_to_nodes, output_name_to_node): matmul_qkv = self.model.find_first_child_by_type(softmax_node, "MatMul", input_name_to_nodes, recursive=False) if matmul_qkv is None: return reshape_qkv = self.model.find_first_child_by_type(matmul_qkv, "Reshape", input_name_to_nodes, recursive=False) if reshape_qkv is None: return transpose_qkv = self.model.find_first_child_by_type( reshape_qkv, "Transpose", input_name_to_nodes, recursive=False ) if transpose_qkv is None: return reshape_out = self.model.find_first_child_by_type( transpose_qkv, "Reshape", input_name_to_nodes, recursive=False ) if reshape_out is None: return matmul_out = self.model.find_first_child_by_type(reshape_out, "MatMul", input_name_to_nodes, recursive=False) if matmul_out is None: return add_out = self.model.find_first_child_by_type(matmul_out, "Add", input_name_to_nodes, recursive=False) if add_out is None: return transpose_out = self.model.find_first_child_by_type(add_out, "Transpose", input_name_to_nodes, recursive=False) if transpose_out is None: return v_nodes = self.model.match_parent_path( matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return (_, _, _, add_v, matmul_v) = v_nodes qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) if qk_nodes is not None: (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes else: logger.debug("fuse_attention: failed to match qk path") return q_nodes = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0, None] ) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return (_, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes k_nodes = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return (_, _, _, _, add_k, matmul_k) = k_nodes attention_last_node = reshape_out q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, add_q) if q_num_heads <= 0: logger.debug("fuse_attention: failed to detect num_heads") return # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads new_node = self.create_attention_node( matmul_q, add_q, matmul_k, add_k, matmul_v, add_v, q_num_heads, q_hidden_size, matmul_q.input[0], attention_last_node.output[0], ) if new_node is None: return self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) # Use prune graph to remove nodes since they are shared by all attention nodes. self.prune_graph = True
Memory