# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

from fusion_attention_clip import FusionAttentionClip
from onnx import ModelProto
from onnx_model_bert import BertOnnxModel

logger = getLogger(__name__)


class ClipOnnxModel(BertOnnxModel):
    """Graph optimizer for CLIP-style ONNX models.

    Extends BertOnnxModel with a CLIP-specific attention fusion pass;
    all other fusions come from the base class.
    """

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        """Initialize the optimizer.

        Args:
            model: The ONNX model to optimize.
            num_heads: Number of attention heads (0 lets fusion detect it).
            hidden_size: Hidden dimension size (0 lets fusion detect it).
        """
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
        self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)

    def get_fused_operator_statistics(self):
        """Return a dict mapping fused operator type to its node count.

        Also logs the counts at INFO level.
        """
        ops = [
            "Attention",
            "Gelu",
            "LayerNormalization",
            "QuickGelu",
            "SkipLayerNormalization",
        ]
        op_count = {op: len(self.get_nodes_by_op_type(op)) for op in ops}
        # Lazy %-formatting: message is only built when INFO is enabled.
        # Output is identical to the previous f-string (dict repr).
        logger.info("Optimized operators:%s", op_count)
        return op_count

    def fuse_attention(self):
        """Run the CLIP-specific attention fusion pass over the model."""
        self.clip_attention_fusion.apply()
# NOTE(review): stray extraction residue ("Memory") — commented out so the module stays valid Python; confirm it carries no meaning and delete.