import copy
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy
import torch
from onnxruntime import InferenceSession, RunOptions
# Type alias
ShapeDict = Mapping[str, Union[Tuple, List[int]]]
logger = logging.getLogger(__name__)
class TypeHelper:
@staticmethod
def get_input_type(ort_session: InferenceSession, name: str) -> str:
        for input in ort_session.get_inputs():
if input.name == name:
return input.type
raise ValueError(f"input name {name} not found")
@staticmethod
    def get_output_type(ort_session: InferenceSession, name: str) -> str:
        for output in ort_session.get_outputs():
if output.name == name:
return output.type
raise ValueError(f"output name {name} not found")
@staticmethod
def ort_type_to_numpy_type(ort_type: str):
ort_type_to_numpy_type_map = {
"tensor(int64)": numpy.longlong,
"tensor(int32)": numpy.intc,
"tensor(float)": numpy.float32,
"tensor(float16)": numpy.float16,
"tensor(bool)": bool,
}
if ort_type not in ort_type_to_numpy_type_map:
raise ValueError(f"{ort_type} not found in map")
return ort_type_to_numpy_type_map[ort_type]
@staticmethod
def ort_type_to_torch_type(ort_type: str):
ort_type_to_torch_type_map = {
"tensor(int64)": torch.int64,
"tensor(int32)": torch.int32,
"tensor(float)": torch.float32,
"tensor(float16)": torch.float16,
"tensor(bool)": torch.bool,
}
if ort_type not in ort_type_to_torch_type_map:
raise ValueError(f"{ort_type} not found in map")
return ort_type_to_torch_type_map[ort_type]
@staticmethod
def numpy_type_to_torch_type(numpy_type: numpy.dtype):
numpy_type_to_torch_type_map = {
numpy.longlong: torch.int64,
numpy.intc: torch.int32,
numpy.int32: torch.int32,
numpy.float32: torch.float32,
numpy.float16: torch.float16,
bool: torch.bool,
}
if numpy_type not in numpy_type_to_torch_type_map:
raise ValueError(f"{numpy_type} not found in map")
return numpy_type_to_torch_type_map[numpy_type]
@staticmethod
def torch_type_to_numpy_type(torch_type: torch.dtype):
torch_type_to_numpy_type_map = {
torch.int64: numpy.longlong,
torch.int32: numpy.intc,
torch.float32: numpy.float32,
torch.float16: numpy.float16,
torch.bool: bool,
}
if torch_type not in torch_type_to_numpy_type_map:
raise ValueError(f"{torch_type} not found in map")
return torch_type_to_numpy_type_map[torch_type]
@staticmethod
def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtype]:
"""Create a mapping from input/output name to numpy data type"""
name_to_numpy_type = {}
for input in ort_session.get_inputs():
name_to_numpy_type[input.name] = TypeHelper.ort_type_to_numpy_type(input.type)
for output in ort_session.get_outputs():
name_to_numpy_type[output.name] = TypeHelper.ort_type_to_numpy_type(output.type)
return name_to_numpy_type
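# A minimal usage sketch for TypeHelper. The model path and the input name "input_ids"
# are assumptions for illustration, not part of this module:
#
#   session = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#   ort_type = TypeHelper.get_input_type(session, "input_ids")  # e.g. "tensor(int64)"
#   np_dtype = TypeHelper.ort_type_to_numpy_type(ort_type)      # numpy.longlong
#   torch_dtype = TypeHelper.ort_type_to_torch_type(ort_type)   # torch.int64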
class IOBindingHelper:
@staticmethod
    def get_output_buffers(ort_session: InferenceSession, output_shapes: ShapeDict, device: torch.device):
"""Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape."""
output_buffers = {}
for name, shape in output_shapes.items():
ort_type = TypeHelper.get_output_type(ort_session, name)
torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch_type, device=device)
return output_buffers
@staticmethod
def prepare_io_binding(
        ort_session: InferenceSession,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past: Optional[List[torch.Tensor]],
        output_buffers: Dict[str, torch.Tensor],
        output_shapes: ShapeDict,
        name_to_np_type: Optional[Dict[str, numpy.dtype]] = None,
):
"""Returnas IO binding object for a session."""
if name_to_np_type is None:
name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session)
# Bind inputs and outputs to onnxruntime session
io_binding = ort_session.io_binding()
# Bind inputs
assert input_ids.is_contiguous()
io_binding.bind_input(
"input_ids",
input_ids.device.type,
0,
name_to_np_type["input_ids"],
list(input_ids.size()),
input_ids.data_ptr(),
)
if past is not None:
for i, past_i in enumerate(past):
assert past_i.is_contiguous()
data_ptr = past_i.data_ptr()
if data_ptr == 0:
                    # When past_sequence_length is 0, the data_ptr of the past tensor is zero, but IO binding
                    # requires a non-zero pointer. Work around this by passing the data pointer of input_ids;
                    # the past data is never read in this case, so its content does not matter.
data_ptr = input_ids.data_ptr()
io_binding.bind_input(
f"past_{i}",
past_i.device.type,
0,
name_to_np_type[f"past_{i}"],
list(past_i.size()),
data_ptr,
)
if attention_mask is not None:
assert attention_mask.is_contiguous()
io_binding.bind_input(
"attention_mask",
attention_mask.device.type,
0,
name_to_np_type["attention_mask"],
list(attention_mask.size()),
attention_mask.data_ptr(),
)
if position_ids is not None:
assert position_ids.is_contiguous()
io_binding.bind_input(
"position_ids",
position_ids.device.type,
0,
name_to_np_type["position_ids"],
list(position_ids.size()),
position_ids.data_ptr(),
)
# Bind outputs
for output in ort_session.get_outputs():
output_name = output.name
output_buffer = output_buffers[output_name]
            logger.debug("%s device type=%s shape=%s", output_name, output_buffer.device.type, list(output_buffer.size()))
io_binding.bind_output(
output_name,
output_buffer.device.type,
0,
name_to_np_type[output_name],
output_shapes[output_name],
output_buffer.data_ptr(),
)
return io_binding
@staticmethod
    def get_outputs_from_io_binding_buffer(
        ort_session: InferenceSession,
        output_buffers: Dict[str, torch.Tensor],
        output_shapes: ShapeDict,
        return_numpy: bool = True,
    ):
        """Copy results to CPU. Returns a list of numpy arrays (or torch tensors when return_numpy is False)."""
ort_outputs = []
for output in ort_session.get_outputs():
output_name = output.name
buffer = output_buffers[output_name]
shape = output_shapes[output_name]
copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach()
if return_numpy:
ort_outputs.append(copy_tensor.cpu().numpy())
else:
ort_outputs.append(copy_tensor)
return ort_outputs
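# A minimal end-to-end sketch for IOBindingHelper, assuming a GPT-2 style model with
# inputs input_ids, position_ids, attention_mask and past_* (all names and shapes below
# are hypothetical):
#
#   output_shapes = {"logits": [1, 8, 50257]}
#   buffers = IOBindingHelper.get_output_buffers(session, output_shapes, device)
#   io_binding = IOBindingHelper.prepare_io_binding(
#       session, input_ids, position_ids, attention_mask, past, buffers, output_shapes
#   )
#   session.run_with_iobinding(io_binding)
#   outputs = IOBindingHelper.get_outputs_from_io_binding_buffer(session, buffers, output_shapes)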
class CudaSession:
"""Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider"""
def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False):
self.ort_session = ort_session
self.input_names = [input.name for input in self.ort_session.get_inputs()]
self.output_names = [output.name for output in self.ort_session.get_outputs()]
self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session)
self.io_binding = self.ort_session.io_binding()
self.enable_cuda_graph = enable_cuda_graph
self.input_tensors = OrderedDict()
self.output_tensors = OrderedDict()
self.device = device
# Pairs of input and output names that share the same buffer.
self.buffer_sharing: Dict[str, str] = {}
def set_buffer_sharing(self, input_name: str, output_name: str):
assert input_name in self.input_names
assert output_name in self.output_names
self.buffer_sharing[input_name] = output_name
self.buffer_sharing[output_name] = input_name
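    # Example (names are hypothetical and model dependent): share one buffer between a
    # past-state input and the corresponding present-state output of a decoder:
    #
    #   cuda_session.set_buffer_sharing("past_key_values.0.key", "present.0.key")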
def __del__(self):
del self.input_tensors
del self.output_tensors
del self.io_binding
def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
device_id = tensor.device.index if tensor.device.index is not None else 0
tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)
self.io_binding.bind_input(
name,
tensor.device.type,
device_id,
self.io_name_to_numpy_type[name],
tensor_shape,
tensor.data_ptr(),
)
if name in self.buffer_sharing:
self.io_binding.bind_output(
self.buffer_sharing[name],
tensor.device.type,
device_id,
self.io_name_to_numpy_type[name],
tensor_shape,
tensor.data_ptr(),
)
self.output_tensors[self.buffer_sharing[name]] = tensor
def allocate_buffers(self, shape_dict: ShapeDict):
"""Allocate tensors for I/O Binding"""
if self.enable_cuda_graph:
for name, shape in shape_dict.items():
if name in self.input_names:
                    # Reuse the allocated buffer when the shape is unchanged
if name in self.input_tensors:
if tuple(self.input_tensors[name].shape) == tuple(shape):
continue
raise RuntimeError("Expect static input shape for cuda graph")
numpy_dtype = self.io_name_to_numpy_type[name]
                    tensor = torch.empty(
                        tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype), device=self.device
                    )
self.input_tensors[name] = tensor
self.bind_input_and_buffer_sharing(name, tensor)
for name, shape in shape_dict.items():
if name in self.output_names:
                # Reuse the allocated buffer when the shape is unchanged
if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
continue
if name in self.buffer_sharing:
continue
numpy_dtype = self.io_name_to_numpy_type[name]
                tensor = torch.empty(
                    tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype), device=self.device
                )
self.output_tensors[name] = tensor
self.io_binding.bind_output(
name,
tensor.device.type,
tensor.device.index if tensor.device.index is not None else 0,
numpy_dtype,
list(tensor.size()),
tensor.data_ptr(),
)
    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: Optional[RunOptions] = None, synchronize: bool = True):
"""Bind input tensors and run inference"""
for name, tensor in feed_dict.items():
assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
if name in self.input_names:
if self.enable_cuda_graph:
assert self.input_tensors[name].nelement() == tensor.nelement()
assert self.input_tensors[name].dtype == tensor.dtype
assert tensor.device.type == "cuda"
self.input_tensors[name].copy_(tensor)
else:
self.bind_input_and_buffer_sharing(name, tensor)
if synchronize:
self.io_binding.synchronize_inputs()
self.ort_session.run_with_iobinding(self.io_binding, run_options)
self.io_binding.synchronize_outputs()
else:
self.ort_session.run_with_iobinding(self.io_binding, run_options)
return self.output_tensors
@staticmethod
def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> Dict[str, Any]:
options = {
"device_id": device_id,
"arena_extend_strategy": "kSameAsRequested",
"enable_cuda_graph": enable_cuda_graph,
}
        # The stream value is the address of a CUDA stream; 0 means the default stream.
if stream != 0:
options["user_compute_stream"] = str(stream)
return options
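# A minimal usage sketch for CudaSession ("model.onnx" and the I/O names and shapes are
# assumptions for illustration):
#
#   provider_options = CudaSession.get_cuda_provider_options(device_id=0, enable_cuda_graph=True)
#   session = InferenceSession(
#       "model.onnx",
#       providers=[("CUDAExecutionProvider", provider_options), "CPUExecutionProvider"],
#   )
#   device = torch.device("cuda", 0)
#   cuda_session = CudaSession(session, device, enable_cuda_graph=True)
#   cuda_session.allocate_buffers({"input_ids": (1, 8), "logits": (1, 8, 50257)})
#   outputs = cuda_session.infer({"input_ids": torch.ones(1, 8, dtype=torch.int64, device=device)})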
class GpuBinding(CudaSession):
def __init__(
self,
ort_session: InferenceSession,
device: torch.device,
shape_dict: ShapeDict,
enable_gpu_graph: bool = False,
gpu_graph_id: int = -1,
stream: int = 0,
buffer_sharing: Optional[Dict[str, str]] = None,
):
super().__init__(ort_session, device, enable_gpu_graph)
if buffer_sharing:
for input_name, output_name in buffer_sharing.items():
self.set_buffer_sharing(input_name, output_name)
self.allocate_buffers(shape_dict)
self.gpu_graph_id = gpu_graph_id
        # For cuda graph, keep a copy of shape_dict so a later inference run can verify that input shapes are unchanged.
self.shape_dict = copy.deepcopy(shape_dict) if enable_gpu_graph else None
self.stream = stream
        # The gpu graph id of the last run. It will be saved to image metadata.
self.last_run_gpu_graph_id = None
def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
options = RunOptions()
gpu_graph_id = -1 if disable_cuda_graph_in_run else self.gpu_graph_id
options.add_run_config_entry("gpu_graph_id", str(gpu_graph_id))
self.last_run_gpu_graph_id = gpu_graph_id
return options
def infer(self, feed_dict: Dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
run_options = self.get_run_options(disable_cuda_graph_in_run)
if self.stream:
run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
return super().infer(feed_dict, run_options)
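# A usage sketch for GpuBinding (session, device, shape_dict and feed_dict are hypothetical).
# With enable_gpu_graph=True the first run typically captures the cuda graph and later runs
# replay it; a single run can still opt out of graph replay:
#
#   binding = GpuBinding(session, device, shape_dict, enable_gpu_graph=True, gpu_graph_id=0)
#   outputs = binding.infer(feed_dict)                                  # capture, then replay
#   outputs = binding.infer(feed_dict, disable_cuda_graph_in_run=True)  # run without the graph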
class GpuBindingManager:
"""A manager for I/O bindings that support multiple CUDA Graphs.
One cuda graph is reused for same input shape. Automatically add a new cuda graph for new input shape.
"""
def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
self.ort_session = ort_session
self.device = device
        # Bindings that support cuda graphs. Each binding can still disable the cuda graph for a specific run.
self.graph_bindings = []
        # Binding used when cuda graph is not enabled.
self.no_graph_binding = None
self.stream = stream
self.max_cuda_graphs = max_cuda_graphs
def get_binding(
self,
shape_dict: ShapeDict,
use_cuda_graph: bool = False,
buffer_sharing: Optional[Dict[str, str]] = None,
) -> GpuBinding:
for gpu_graph_binding in self.graph_bindings:
            # Found a cuda graph captured with the same input shapes
if gpu_graph_binding.shape_dict == shape_dict:
return gpu_graph_binding
        # Reached the maximum number of cuda graphs. Return a binding without a cuda graph.
if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
if self.no_graph_binding is None:
self.no_graph_binding = GpuBinding(
self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
)
else:
self.no_graph_binding.allocate_buffers(shape_dict)
return self.no_graph_binding
# This is a new input shape, create a new cuda graph
gpu_graph_binding = GpuBinding(
self.ort_session,
self.device,
shape_dict,
enable_gpu_graph=True,
gpu_graph_id=len(self.graph_bindings),
stream=self.stream,
buffer_sharing=buffer_sharing,
)
self.graph_bindings.append(gpu_graph_binding)
return gpu_graph_binding
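# A usage sketch for GpuBindingManager (shapes are hypothetical). The manager keeps one
# binding (and cuda graph) per distinct input shape, up to max_cuda_graphs, and falls
# back to a graph-free binding beyond that:
#
#   manager = GpuBindingManager(session, torch.device("cuda", 0), max_cuda_graphs=2)
#   binding = manager.get_binding({"input_ids": (1, 8), "logits": (1, 8, 50257)}, use_cuda_graph=True)
#   input_ids = torch.ones(1, 8, dtype=torch.int64, device="cuda")
#   outputs = binding.infer({"input_ids": input_ids})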