# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional

from fusion_attention_unet import FusionAttentionUnet
from fusion_bias_add import FusionBiasAdd
from fusion_biassplitgelu import FusionBiasSplitGelu
from fusion_group_norm import FusionGroupNorm
from fusion_nhwc_conv import FusionNhwcConv
from fusion_options import FusionOptions
from fusion_skip_group_norm import FusionSkipGroupNorm
from fusion_transpose import FusionInsertTranspose, FusionTranspose
from import_utils import is_installed
from onnx import ModelProto
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class UnetOnnxModel(BertOnnxModel):
    """Graph optimizer for UNet ONNX models, built on top of BertOnnxModel.

    The entry point is optimize(), which runs a fixed sequence of fusion and
    clean-up stages; each stage can be switched on or off through FusionOptions.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize UNet ONNX Model.

        Args:
            model (ModelProto): the ONNX model
            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
        """
        # Either both values are left for auto-detection, or num_heads is given and divides hidden_size.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def preprocess(self):
        """Clean up the graph before the UNet-specific fusions: drop Div-by-1 nodes."""
        self.remove_useless_div()

    def postprocess(self):
        """Final clean-up after all fusions: prune the graph, then remove unused constants."""
        self.prune_graph()
        self.remove_unused_constant()

    def remove_useless_div(self):
        """Remove Div by 1"""
        # A Div qualifies when find_constant_input reports the constant 1.0 at input index 1.
        nodes_to_remove = [
            node for node in self.nodes() if node.op_type == "Div" and self.find_constant_input(node, 1.0) == 1
        ]

        # Re-link every consumer of a removed Div straight to the Div's first input.
        for node in nodes_to_remove:
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            logger.info("Removed %d Div nodes", len(nodes_to_remove))

    def convert_conv_to_nhwc(self):
        """Apply the FusionNhwcConv fusion with weight update enabled."""
        # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes.
        conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True)
        conv_to_nhwc_conv.apply()

    def merge_adjacent_transpose(self):
        """Apply FusionTranspose, then remove Transpose nodes whose perm is the identity permutation."""
        fusion_transpose = FusionTranspose(self)
        fusion_transpose.apply()

        remove_count = 0
        nodes = self.get_nodes_by_op_type("Transpose")
        for node in nodes:
            permutation = OnnxModel.get_node_attribute(node, "perm")
            assert isinstance(permutation, list)

            # Only an identity permutation ([0, 1, ..., n-1]) makes the Transpose a no-op.
            if permutation != list(range(len(permutation))):
                continue

            # The rewiring below is only done when neither end of the node is a graph input/output.
            assert not (
                self.find_graph_output(node.output[0])
                or self.find_graph_input(node.input[0])
                or self.find_graph_output(node.input[0])
            )

            # Let all children nodes skip current Transpose node and link to its parent
            # Note that we cannot update parent node output since parent node might have more than one children.
            self.replace_input_of_all_nodes(node.output[0], node.input[0])

            self.remove_node(node)
            remove_count += 1

        total = len(fusion_transpose.nodes_to_remove) + remove_count
        if total:
            logger.info("Removed %d Transpose nodes", total)

    def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None):
        """Run the UNet attention fusion twice: first for self attention, then for cross attention.

        Args:
            options (FusionOptions, optional): supplies enable_packed_qkv (self attention) and
                enable_packed_kv (cross attention). When None, both packed formats are enabled.
        """
        # Self Attention
        enable_packed_qkv = (options is None) or options.enable_packed_qkv
        self_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=False,
            enable_packed_qkv=enable_packed_qkv,
            enable_packed_kv=False,
        )
        self_attention_fusion.apply()

        # Cross Attention
        enable_packed_kv = (options is None) or options.enable_packed_kv
        cross_attention_fusion = FusionAttentionUnet(
            self,
            self.hidden_size,
            self.num_heads,
            is_cross_attention=True,
            enable_packed_qkv=False,
            enable_packed_kv=enable_packed_kv,
        )
        cross_attention_fusion.apply()

    def fuse_bias_add(self):
        """Apply the FusionBiasAdd fusion."""
        fusion = FusionBiasAdd(self)
        fusion.apply()

    def optimize(self, options: Optional[FusionOptions] = None):
        """Run the optimization pipeline, showing a tqdm progress bar when tqdm is installed.

        Args:
            options (FusionOptions, optional): fusion switches forwarded to _optimize().
        """
        if is_installed("tqdm"):
            import tqdm
            from tqdm.contrib.logging import logging_redirect_tqdm

            # Must equal the number of progress ticks issued by _optimize().
            steps = 18
            with logging_redirect_tqdm():
                # Use the bar as a context manager so it is always closed, even if a fusion raises.
                with tqdm.tqdm(range(steps), initial=0, desc="fusion") as progress_bar:
                    self._optimize(options, progress_bar)
        else:
            logger.info("tqdm is not installed. Run optimization without progress bar")
            self._optimize(options, None)

    def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None):
        """Run every optimization stage in order.

        Args:
            options (FusionOptions, optional): per-stage switches. When None, every stage is enabled
                except GELU approximation and BiasAdd fusion, which require an explicit opt-in.
            progress_bar (optional): object with an update(n) method (e.g. a tqdm bar). It is advanced
                by one after each stage, including stages that were skipped.
        """

        def advance():
            # Same truthiness test as before: None (or a falsy bar) means "no progress reporting".
            if progress_bar:
                progress_bar.update(1)

        if (options is not None) and not options.enable_shape_inference:
            self.disable_shape_inference()

        self.utils.remove_identity_nodes()
        advance()

        # Remove Cast nodes whose input and output have the same data type, based on symbolic shape inference.
        self.utils.remove_useless_cast_nodes()
        advance()

        if (options is None) or options.enable_layer_norm:
            self.fuse_layer_norm()
        advance()

        if (options is None) or options.enable_gelu:
            self.fuse_gelu()
        advance()

        self.preprocess()
        advance()

        self.fuse_reshape()
        advance()

        if (options is None) or options.enable_group_norm:
            channels_last = (options is None) or options.group_norm_channels_last
            group_norm_fusion = FusionGroupNorm(self, channels_last)
            group_norm_fusion.apply()

            insert_transpose_fusion = FusionInsertTranspose(self)
            insert_transpose_fusion.apply()
        advance()

        if (options is None) or options.enable_bias_splitgelu:
            bias_split_gelu_fusion = FusionBiasSplitGelu(self)
            bias_split_gelu_fusion.apply()
        advance()

        if (options is None) or options.enable_attention:
            # self.save_model_to_file("before_mha.onnx")
            self.fuse_multi_head_attention(options)
        advance()

        if (options is None) or options.enable_skip_layer_norm:
            self.fuse_skip_layer_norm()
        advance()

        self.fuse_shape()
        advance()

        # Remove Reshape nodes whose input and output have the same shape, based on symbolic shape inference.
        self.utils.remove_useless_reshape_nodes()
        advance()

        if (options is None) or options.enable_skip_group_norm:
            skip_group_norm_fusion = FusionSkipGroupNorm(self)
            skip_group_norm_fusion.apply()
        advance()

        if (options is None) or options.enable_bias_skip_layer_norm:
            # Fuse SkipLayerNormalization and Add Bias before it.
            self.fuse_add_bias_skip_layer_norm()
        advance()

        # Opt-in only: skipped when options is None.
        if options is not None and options.enable_gelu_approximation:
            self.gelu_approximation()
        advance()

        if options is None or options.enable_nhwc_conv:
            self.convert_conv_to_nhwc()
            self.merge_adjacent_transpose()
        advance()

        # Opt-in only: skipped when options is None.
        if options is not None and options.enable_bias_add:
            self.fuse_bias_add()
        advance()

        self.postprocess()
        advance()

        logger.info(f"opset version: {self.get_opset_version()}")

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        ops = [
            "Attention",
            "MultiHeadAttention",
            "LayerNormalization",
            "SkipLayerNormalization",
            "BiasSplitGelu",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
            "BiasAdd",
        ]
        # Maps each operator type to the number of such nodes currently in the graph.
        op_count = {op: len(self.get_nodes_by_op_type(op)) for op in ops}

        logger.info(f"Optimized operators:{op_count}")
        return op_count
Memory