# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- from __future__ import annotations import json from collections.abc import MutableMapping from dataclasses import dataclass from typing import Any import onnx from .quant_utils import QuantType @dataclass class QuantTypeInfo: """ The quantization type information for a tensor override. """ quant_type: QuantType symmetric: bool | None = None # If None, assumes default is used. reduce_range: bool | None = None # If None, assumes default is used. axis: int | None = None # If None, assumes per-tensor quantization def __eq__(self, other: object): if isinstance(other, QuantTypeInfo): return ( self.quant_type == other.quant_type and (self.symmetric is None or other.symmetric is None or self.symmetric == other.symmetric) and (self.reduce_range is None or other.reduce_range is None or self.reduce_range == other.reduce_range) and (self.axis == other.axis) ) return NotImplemented @staticmethod def load_from_dict( raw_dict: dict[str, Any], default_qtype: QuantType | None = None, default_symmetric: bool | None = None, default_reduce_range: bool | None = None, ) -> QuantTypeInfo: return QuantTypeInfo( raw_dict.get("quant_type", default_qtype), raw_dict.get("symmetric", default_symmetric), raw_dict.get("reduce_range", default_reduce_range), raw_dict.get("axis"), ) def save_to_dict(self, raw_dict: dict[str, Any]): raw_dict["quant_type"] = self.quant_type if self.symmetric is not None: raw_dict["symmetric"] = self.symmetric if self.reduce_range is not None: raw_dict["reduce_range"] = self.reduce_range if self.axis is not None: raw_dict["axis"] = self.axis class TensorQuantOverridesHelper(MutableMapping): """ Utility wrapper over the tensor quantization overrides passed via extra_options. """ def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]): self.overrides = raw_overrides self.quant_types = None self.keys_unsupported_with_scale_zp = {"symmetric", "reduce_range", "rmax", "rmin"} def has_per_tensor_overrides(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and "axis" not in overrides_list[0] def has_per_channel_overrides(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and "axis" in overrides_list[0] def overrides_scale_zp(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0]) def get_per_tensor_overrides( self, tensor_name: str, default_val: dict[str, Any] | None = None, ) -> dict[str, Any] | None: default_list_val = [default_val] if default_val is not None else None overrides_list = self.overrides.get(tensor_name, default_list_val) if overrides_list and "axis" in overrides_list[0]: raise ValueError( f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, " f"but found per-channel overrides." ) return overrides_list[0] if overrides_list else None def get_per_channel_overrides( self, tensor_name: str, default_val: list[dict[str, Any]] | None = None, ) -> list[dict[str, Any]] | None: overrides_list = self.overrides.get(tensor_name, default_val) if not overrides_list: return None if "axis" not in overrides_list[0]: raise ValueError( f"Expected tensor '{tensor_name}' to have per-channel quantization overrides (axis value is missing).", ) return overrides_list def get_quant_types(self) -> set[QuantType]: if self.quant_types is not None: return self.quant_types self.quant_types = set() if self.overrides: for quant_overrides_list in self.overrides.values(): for quant_overrides in quant_overrides_list: if "quant_type" in quant_overrides: self.quant_types.add(quant_overrides["quant_type"]) if "convert" in quant_overrides and "quant_type" in quant_overrides["convert"]: self.quant_types.add(quant_overrides["convert"]["quant_type"]) return self.quant_types def _is_valid_per_tensor( self, initializers, default_activation_qtype, tensor_name: str, quant_overrides: dict[str, Any], ) -> tuple[bool, str | None]: if not isinstance(quant_overrides, dict): return ( False, f"Tensor quantization overrides for '{tensor_name}' are not in a dict", ) is_initializer = tensor_name in initializers quant_type = quant_overrides.get("quant_type") if quant_type: self.quant_types.add(quant_type) has_scale = "scale" in quant_overrides has_zero_point = "zero_point" in quant_overrides if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): return ( False, "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", ) if has_scale: keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides)) if keys: return ( False, f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", ) if "reduce_range" in quant_overrides and not is_initializer: return ( False, f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", ) if "convert" in quant_overrides: if is_initializer: return False, "Cannot use 'convert' override for initializers" if "quant_type" not in quant_overrides["convert"]: return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'" if "reduce_range" in quant_overrides["convert"]: return ( False, f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}", ) convert_quant_type = quant_overrides["convert"]["quant_type"] original_quant_type = quant_type if quant_type is not None else default_activation_qtype if convert_quant_type == original_quant_type: return ( False, f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')", ) convert_has_scale = "scale" in quant_overrides["convert"] convert_has_zero_point = "zero_point" in quant_overrides["convert"] if (convert_has_scale and not convert_has_zero_point) or (convert_has_zero_point and not convert_has_scale): return ( False, f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')", ) if convert_has_scale: keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides["convert"])) if keys: return ( False, f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point' " f"(tensor '{tensor_name}')", ) self.quant_types.add(convert_quant_type) return True, None def _is_valid_per_channel( self, initializers, tensor_name: str, quant_overrides_list: list[dict[str, Any]], ) -> tuple[bool, str | None]: is_initializer = tensor_name in initializers if not is_initializer: return ( False, f"Tensor '{tensor_name}' has per-channel overrides, but is not an initializer", ) axis = quant_overrides_list[0].get("axis") if axis is None: return ( False, f"Per-channel overrides for tensor {tensor_name} is missing an 'axis' value in " "the first channel dictionary.", ) weight_shape = list(initializers[tensor_name].dims) weight_rank = len(weight_shape) norm_axis = axis if norm_axis < 0: norm_axis += weight_rank if norm_axis < 0 or norm_axis >= len(weight_shape): return ( False, f"Axis override value is out-of-bounds for tensor {tensor_name} (rank {len(weight_shape)})", ) if len(quant_overrides_list) > 1 and len(quant_overrides_list) != weight_shape[norm_axis]: return ( False, f"Incorrect number of channel overrides for tensor {tensor_name} (axis {axis}), " f"expected {weight_shape[axis]}, but found {len(quant_overrides_list)}.", ) if "convert" in quant_overrides_list[0]: return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}." quant_type = quant_overrides_list[0].get("quant_type") if quant_type: self.quant_types.add(quant_type) symmetric = quant_overrides_list[0].get("symmetric") reduce_range = quant_overrides_list[0].get("reduce_range") has_scale = "scale" in quant_overrides_list[0] has_zero_point = "zero_point" in quant_overrides_list[0] has_scale_zp = has_scale and has_zero_point if (has_scale and not has_zero_point) or (has_zero_point and not has_scale): return ( False, "Must provide both 'scale' and 'zero_point' if one of the overrides is provided", ) if has_scale_zp: keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides_list[0])) if keys: return ( False, f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", ) has_rmin = "rmin" in quant_overrides_list[0] has_rmax = "rmax" in quant_overrides_list[0] has_rmin_rmax = has_rmin and has_rmax if (has_rmin and not has_rmax) or (not has_rmin and has_rmax): return ( False, "Must provide both 'rmin' and 'rmax' if one is provided", ) for index, quant_overrides in enumerate(quant_overrides_list[1:]): if not isinstance(quant_overrides, dict): return ( False, f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict", ) if "convert" in quant_overrides: return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}." # For per-channel quantization, all channels must use the same quantization type, axis, symmetric # and reduce_range values. And, if specified, they must be present in the first channel dict # (i.e., quant_overrides_list[0]). if "quant_type" in quant_overrides and quant_type != quant_overrides["quant_type"]: return ( False, "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.", ) if "axis" in quant_overrides and axis != quant_overrides["axis"] and norm_axis != quant_overrides["axis"]: return ( False, "Channel axis for tensor '{tensor_name}' does not match at index {index}.", ) if "symmetric" in quant_overrides and symmetric != quant_overrides["symmetric"]: return ( False, "Channel symmetric value for tensor '{tensor_name}' does not match at index {index}.", ) if "reduce_range" in quant_overrides and reduce_range != quant_overrides["reduce_range"]: return ( False, "Channel reduce_range value for tensor '{tensor_name}' does not match at index {index}.", ) # If override scale/zp, must do so for all channels. chan_has_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides if has_scale_zp and not chan_has_scale_zp: return ( False, "Per-channel overrides that specify scale/zero_point must do so for all channels, " f"but tensor '{tensor_name}' is missing them at index {index}.", ) if chan_has_scale_zp: keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides)) if keys: return ( False, f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'", ) # If override rmin/rmax, must do so for all channels. chan_has_rmin_rmax = "rmin" in quant_overrides and "rmax" in quant_overrides if has_rmin_rmax and not chan_has_rmin_rmax: return ( False, "Per-channel overrides that specify rmin/rmax must do so for all channels, " f"but tensor '{tensor_name}' is missing them at index {index}.", ) return True, None def is_valid( self, initializers: dict[str, onnx.TensorProto], activation_names: set[str], default_activation_qtype, ) -> tuple[bool, str | None]: self.quant_types = set() # Validate that compatible/valid overrides are provided. if self.overrides: for tensor_name, quant_overrides_list in self.overrides.items(): if tensor_name not in initializers and tensor_name not in activation_names: return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model" if not isinstance(quant_overrides_list, list): return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list" if not quant_overrides_list: continue if not isinstance(quant_overrides_list[0], dict): return False, f"Tensor quantization overrides at index 0 for '{tensor_name}' are not in a dict" if not quant_overrides_list[0]: continue axis = quant_overrides_list[0].get("axis") is_per_channel = len(quant_overrides_list) > 1 or axis is not None if is_per_channel: return self._is_valid_per_channel(initializers, tensor_name, quant_overrides_list) return self._is_valid_per_tensor( initializers, default_activation_qtype, tensor_name, quant_overrides_list[0] ) return True, None def update_tensor_overrides( self, tensor_name: str, new_vals: dict[str, Any], channels: list[int] | None = None, overwrite: bool = True, ) -> bool: if not new_vals: return False channels = set(channels) if channels is not None else None have_overrides = self.overrides.get(tensor_name) # If `overwrite` is False, check if we would overwrite anything. do_update = True if not overwrite and have_overrides: for channel, overrides in enumerate(self.overrides[tensor_name]): if channels is not None and channel not in channels: continue if set(new_vals).intersection(set(overrides)): do_update = False break # Do the update if `overwrite` is True or if nothing is overwritten (do not want partial overwrites). if do_update: if not have_overrides: self.overrides[tensor_name] = [{}] for channel, overrides in enumerate(self.overrides[tensor_name]): if channels is not None and channel not in channels: continue overrides.update(new_vals) return do_update def get_node_output_qtype_info( self, output_name: str, default_qtype: QuantType | None, default_symmetric: bool | None = None, ) -> QuantTypeInfo: # Outputs are activations, which do not support 'reduce_range' or 'axis' if output_name not in self.overrides: return QuantTypeInfo(default_qtype, default_symmetric) tensor_overrides = self.overrides[output_name][0] return QuantTypeInfo( tensor_overrides.get("quant_type", default_qtype), tensor_overrides.get("symmetric", default_symmetric), ) def get_node_input_qtype_info( self, input_name: str, node_name: str, default_qtype: QuantType | None, default_symmetric: bool | None = None, default_reduce_range: bool | None = None, ) -> QuantTypeInfo: if input_name not in self.overrides or not self.overrides[input_name]: return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range) # Get the first overrides dict in the list. This works for both per-tensor and per-channel # quantization because all channels must use the same quant type. tensor_overrides = self.overrides[input_name][0] producer_type = tensor_overrides.get("quant_type", default_qtype) if "convert" not in tensor_overrides: return QuantTypeInfo( producer_type, tensor_overrides.get("symmetric", default_symmetric), tensor_overrides.get("reduce_range", default_reduce_range), tensor_overrides.get("axis"), ) # This tensor is converted. Check if the node gets the original qtype or the converted qtype. convert_dict = tensor_overrides["convert"] qtype_info = QuantTypeInfo( producer_type, convert_dict.get("symmetric", default_symmetric), # Converted tensors are not initializers, so do not have 'axis' or 'reduce_range'. ) # Check if all nodes receive the converted type (i.e., recv_nodes is None) or this node # is in the list of consumers (recv_nodes). if ("recv_nodes" not in convert_dict) or (node_name in convert_dict["recv_nodes"]): qtype_info.quant_type = convert_dict["quant_type"] return qtype_info def pprint_str(self, indent=None) -> str: return json.dumps(self.overrides, default=str, indent=indent) def empty(self) -> bool: return not self.overrides def get_dict(self) -> dict[str, list[dict[str, Any]]]: return self.overrides # Required implementations of abstract methods in collections.abc.MutableMapping # so that this class can be used like a dict. def __setitem__(self, key: str, value: list[dict]): self.overrides[key] = value def __getitem__(self, key: str) -> list[dict]: return self.overrides[key] def __delitem__(self, key: str): del self.overrides[key] def __iter__(self): return iter(self.overrides) def __len__(self): return len(self.overrides) def __str__(self) -> str: return str(self.overrides) def __repr__(self) -> str: return f"{super().__repr__()}, TensorQuantOverridesHelper({self.overrides})"
Memory