# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Optional

from fusion_attention_vae import FusionAttentionVae
from fusion_options import FusionOptions
from onnx import ModelProto
from onnx_model_unet import UnetOnnxModel

logger = getLogger(__name__)


class VaeOnnxModel(UnetOnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None):
        # Self Attention
        self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads)
        self_attention_fusion.apply()

    def get_fused_operator_statistics(self):
        """Returns node count of fused operators."""
        op_count = {}
        ops = [
            "Attention",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count
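

# The block below is an illustrative sketch, not part of the original module: it shows
# one way VaeOnnxModel might be exercised directly by loading an ONNX VAE graph,
# applying the self-attention fusion, and printing the fused-operator statistics.
# The file name "vae_decoder.onnx" is a placeholder assumption; in practice this class
# is typically instantiated by the transformers optimizer pipeline rather than run standalone.
if __name__ == "__main__":
    import onnx

    # num_heads=0 and hidden_size=0 let the fusion infer attention shape from the graph.
    vae = VaeOnnxModel(onnx.load("vae_decoder.onnx"), num_heads=0, hidden_size=0)
    vae.fuse_multi_head_attention()
    print(vae.get_fused_operator_statistics())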