# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import numpy as np
from onnxruntime import (
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
get_available_providers,
get_device,
)
from .logger import get_logger
class EP(Enum):
    """onnxruntime execution-provider names that ``OrtInferSession`` can select."""

    # Appended last to every provider list built by OrtInferSession (fallback).
    CPU_EP = "CPUExecutionProvider"
    # NVIDIA GPU provider; needs the onnxruntime-gpu package installed.
    CUDA_EP = "CUDAExecutionProvider"
    # DirectML provider; only considered on Windows 10 and above.
    DIRECTML_EP = "DmlExecutionProvider"
class OrtInferSession:
    """Thin wrapper around an onnxruntime ``InferenceSession``.

    Builds session options and an ordered execution-provider list from a
    plain config dict. CUDA and/or DirectML are put ahead of the CPU
    provider when the config requests them and the installed onnxruntime
    build reports them as available.

    Config keys read by this class: ``model_path``, ``use_cuda``,
    ``use_dml``, ``intra_op_num_threads``, ``inter_op_num_threads``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.logger = get_logger("OrtInferSession")

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        ep_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        # str(): some older onnxruntime releases accept only str/bytes model
        # paths and raise TypeError for pathlib.Path objects.
        self.session = InferenceSession(
            str(model_path),
            sess_options=sess_opt,
            providers=ep_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> "SessionOptions":
        """Build the ``SessionOptions`` used for every session.

        Log severity is set to 4 (the highest onnxruntime level, so only
        fatal messages are logged). The CPU memory arena is disabled and all
        graph optimizations are enabled. Thread counts from ``config`` are
        applied only when ``_is_valid_thread_num`` accepts them. Otherwise
        onnxruntime's own defaults are kept.
        """
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # May be None when the CPU count cannot be determined.
        cpu_nums = os.cpu_count()

        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(intra_op_num_threads, cpu_nums):
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(inter_op_num_threads, cpu_nums):
            sess_opt.inter_op_num_threads = inter_op_num_threads
        return sess_opt

    @staticmethod
    def _is_valid_thread_num(num_threads: Any, cpu_nums: Union[int, None]) -> bool:
        """Return True if ``num_threads`` is an explicit, usable thread count.

        ``-1`` (the config default) and ``None`` mean "not set".
        ``cpu_nums`` is the result of ``os.cpu_count()``, which can be
        ``None``. In that case only the lower bound of 1 is enforced instead
        of crashing on ``n <= None``.
        """
        if num_threads is None or num_threads == -1:
            return False
        if cpu_nums is None:
            return num_threads >= 1
        return 1 <= num_threads <= cpu_nums

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Return ``(provider_name, provider_options)`` pairs in priority order.

        The CPU provider is always present and always last. CUDA is put in
        front of it when ``_check_cuda`` passes. DirectML is put in front of
        everything when ``_check_dml`` passes.

        Side effect: sets ``self.use_cuda`` and ``self.use_directml``.
        """
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        ep_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            ep_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, try to use DirectML as primary provider"
            )
            # NOTE(review): DirectML reuses the CUDA/CPU option dict here;
            # confirm the DML provider accepts (or ignores) these keys.
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            ep_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return ep_list

    def _check_cuda(self) -> bool:
        """Return True if CUDA was requested and looks usable.

        Requires ``use_cuda`` in the config, ``get_device()`` reporting
        "GPU", and the CUDA provider among the available providers. When
        CUDA was requested but is unusable, logs installation hints and
        returns False.
        """
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime packages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    @staticmethod
    def _parse_windows_major_version(release: str, version: str) -> Union[int, None]:
        """Return the Windows major version number, or None if it can't be parsed.

        ``release`` (``platform.release()``, e.g. "10", "11", "8.1") is tried
        first. Some Windows builds report a non-numeric release, such as
        "Vista" or server releases like "2019Server". For those, the first
        component of ``version`` (``platform.version()``, e.g. "10.0.17763")
        is used instead.
        """
        for candidate in (release, version):
            head = candidate.split(".")[0].strip()
            if head.isdecimal():
                return int(head)
        return None

    def _check_dml(self) -> bool:
        """Return True if DirectML was requested and looks usable.

        Requires ``use_dml`` in the config, a Windows OS with major version
        >= 10, and the DirectML provider among the available providers.
        Logs the reason and returns False otherwise.
        """
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        cur_window_version = self._parse_windows_major_version(
            platform.release(), platform.version()
        )
        if cur_window_version is None:
            # Previously int() raised ValueError here on non-numeric releases.
            self.logger.warning(
                "Unable to determine the Windows version (release=%s, version=%s). Use %s inference by default.",
                platform.release(),
                platform.version(),
                self.had_providers[0],
            )
            return False
        if cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime packages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        """Warn if a requested accelerator is missing from the created session.

        Membership in the session's provider list is checked, not "is
        first". With both CUDA and DirectML enabled, DirectML legitimately
        comes first, and a first-position check would wrongly report CUDA
        as unavailable.
        """
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and EP.CUDA_EP.value not in session_providers:
            self.logger.warning(
                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and EP.DIRECTML_EP.value not in session_providers:
            self.logger.warning(
                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: np.ndarray) -> List[np.ndarray]:
        """Run the model on one input array.

        ``input_content`` is bound to the model's first input name (zip()
        stops at the shortest sequence, so any further model inputs are left
        unbound). Returns whatever ``InferenceSession.run`` returns for all
        model outputs.

        Raises:
            ONNXRuntimeError: wraps any exception raised by the session; the
                message is the formatted traceback and the original exception
                is chained as ``__cause__``.
        """
        input_dict = dict(zip(self.get_input_names(), [input_content]))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the model's input names in session order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the model's output names in session order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return custom-metadata entry ``key`` split into lines.

        Raises:
            KeyError: if the model metadata has no ``key`` entry.
        """
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Return True if the model's custom metadata contains ``key``."""
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Validate that ``model_path`` names an existing regular file.

        Raises:
            ValueError: if ``model_path`` is None.
            FileNotFoundError: if the path does not exist.
            FileExistsError: if the path exists but is not a file (kept for
                backward compatibility with existing callers).
        """
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
    """Error raised when an onnxruntime inference call fails.

    ``OrtInferSession.__call__`` raises it with the formatted traceback of
    the underlying exception as its message, chained via ``raise ... from``.
    """