# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import argparse
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from PIL import Image
# Parent of the directory that contains this file; relative model paths in
# the config are joined onto this directory.
root_dir = Path(__file__).resolve().parent.parent
# Alias for the value types grouped under "input": a str, a numpy array, raw
# bytes, a Path or a PIL image.
# NOTE(review): presumably the accepted forms of an input image -- confirm
# against the callers that use this alias.
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
def update_model_path(config: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite each module's ``model_path`` so it is joined onto ``root_dir``.

    The "Det", "Rec" and "Cls" sections of ``config`` are modified in place;
    each ``model_path`` entry is replaced by ``str(root_dir / model_path)``.

    Args:
        config: Configuration holding "Det", "Rec" and "Cls" sections, each
            with a ``model_path`` entry.

    Returns:
        The same ``config`` object, after the in-place rewrite.
    """
    for section_name in ("Det", "Rec", "Cls"):
        section = config[section_name]
        section["model_path"] = str(root_dir / section["model_path"])
    return config
def _split_list_arg(value: str) -> List[str]:
    """Split one CLI token such as ``3,48,192`` or ``[3, 48, 192]`` into items."""
    stripped = value.strip().strip("[]()")
    items = [item.strip().strip("'\"") for item in stripped.split(",")]
    items = [item for item in items if item]
    if not items:
        raise argparse.ArgumentTypeError(
            f"expected a comma-separated list, got {value!r}"
        )
    return items


def _int_list(value: str) -> List[int]:
    """argparse ``type`` callable: parse a comma-separated list of integers."""
    try:
        return [int(item) for item in _split_list_arg(value)]
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers, got {value!r}"
        ) from exc


def _str_list(value: str) -> List[str]:
    """argparse ``type`` callable: parse a comma-separated list of strings."""
    return _split_list_arg(value)


def init_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Options are organised in the groups "Global", "Det", "Cls", "Rec" and
    "Visual Result". List-valued options (``--cls_image_shape``,
    ``--cls_label_list``, ``--rec_img_shape``) take a single comma-separated
    token, e.g. ``--cls_image_shape 3,48,192``; surrounding brackets are
    tolerated.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed arguments namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-img", "--img_path", type=str, default=None, required=True)
    parser.add_argument("-p", "--print_cost", action="store_true", default=False)

    global_group = parser.add_argument_group(title="Global")
    global_group.add_argument("--text_score", type=float, default=0.5)
    global_group.add_argument("--no_det", action="store_true", default=False)
    global_group.add_argument("--no_cls", action="store_true", default=False)
    global_group.add_argument("--no_rec", action="store_true", default=False)
    global_group.add_argument("--print_verbose", action="store_true", default=False)
    global_group.add_argument("--min_height", type=int, default=30)
    global_group.add_argument("--width_height_ratio", type=int, default=8)
    global_group.add_argument("--intra_op_num_threads", type=int, default=-1)
    global_group.add_argument("--inter_op_num_threads", type=int, default=-1)

    det_group = parser.add_argument_group(title="Det")
    det_group.add_argument("--det_use_cuda", action="store_true", default=False)
    det_group.add_argument("--det_use_dml", action="store_true", default=False)
    det_group.add_argument("--det_model_path", type=str, default=None)
    det_group.add_argument("--det_limit_side_len", type=float, default=736)
    det_group.add_argument(
        "--det_limit_type", type=str, default="min", choices=["max", "min"]
    )
    det_group.add_argument("--det_thresh", type=float, default=0.3)
    det_group.add_argument("--det_box_thresh", type=float, default=0.5)
    det_group.add_argument("--det_unclip_ratio", type=float, default=1.6)
    det_group.add_argument(
        "--det_donot_use_dilation", action="store_true", default=False
    )
    det_group.add_argument(
        "--det_score_mode", type=str, default="fast", choices=["slow", "fast"]
    )

    cls_group = parser.add_argument_group(title="Cls")
    cls_group.add_argument("--cls_use_cuda", action="store_true", default=False)
    cls_group.add_argument("--cls_use_dml", action="store_true", default=False)
    cls_group.add_argument("--cls_model_path", type=str, default=None)
    # ``type=list`` would split the raw string into single characters, so the
    # list-valued options use dedicated comma-separated parsers instead.
    cls_group.add_argument(
        "--cls_image_shape",
        type=_int_list,
        default=[3, 48, 192],
        help="Comma-separated integers, e.g. 3,48,192.",
    )
    cls_group.add_argument(
        "--cls_label_list",
        type=_str_list,
        default=["0", "180"],
        help="Comma-separated labels, e.g. 0,180.",
    )
    cls_group.add_argument("--cls_batch_num", type=int, default=6)
    cls_group.add_argument("--cls_thresh", type=float, default=0.9)

    rec_group = parser.add_argument_group(title="Rec")
    rec_group.add_argument("--rec_use_cuda", action="store_true", default=False)
    rec_group.add_argument("--rec_use_dml", action="store_true", default=False)
    rec_group.add_argument("--rec_model_path", type=str, default=None)
    rec_group.add_argument("--rec_keys_path", type=str, default=None)
    rec_group.add_argument(
        "--rec_img_shape",
        type=_int_list,
        default=[3, 48, 320],
        help="Comma-separated integers, e.g. 3,48,320.",
    )
    rec_group.add_argument("--rec_batch_num", type=int, default=6)

    vis_group = parser.add_argument_group(title="Visual Result")
    vis_group.add_argument("-vis", "--vis_res", action="store_true", default=False)
    vis_group.add_argument(
        "--vis_font_path",
        type=str,
        default=None,
        help="When -vis is True, the font_path must have value.",
    )
    vis_group.add_argument(
        "--vis_save_path",
        type=str,
        default=".",
        help="The directory of saving the vis image.",
    )

    args = parser.parse_args(argv)
    return args
class UpdateParameters:
    """Merge flat keyword overrides into a sectioned configuration.

    Calling an instance with a config holding "Global", "Det", "Cls" and
    "Rec" sections plus keyword overrides routes each override to its section
    by key prefix (``det``/``cls``/``rec``, everything else is global) and
    updates the sections in place.
    """

    def __init__(self) -> None:
        pass

    def parse_kwargs(self, **kwargs):
        """Split keyword overrides into per-section dictionaries.

        Keys starting with ``det`` go to the Det dict with a leading ``det_``
        removed; ``donot_use_dilation`` is stored as the negated
        ``use_dilation``. Keys starting with ``cls`` / ``rec`` go to their
        dicts unchanged. All remaining keys are treated as global.

        Returns:
            Tuple ``(global_dict, det_dict, cls_dict, rec_dict)``.
        """
        global_dict, det_dict, cls_dict, rec_dict = {}, {}, {}, {}
        det_prefix = "det_"
        for k, v in kwargs.items():
            if k.startswith("det"):
                # Strip only the leading prefix; splitting on "det_" would
                # truncate keys that contain it again further in, and fail on
                # keys that lack the underscore.
                if k.startswith(det_prefix):
                    k = k[len(det_prefix):]
                if k == "donot_use_dilation":
                    # The CLI flag is negative, the config entry is positive.
                    k = "use_dilation"
                    v = not v
                det_dict[k] = v
            elif k.startswith("cls"):
                cls_dict[k] = v
            elif k.startswith("rec"):
                rec_dict[k] = v
            else:
                global_dict[k] = v
        return global_dict, det_dict, cls_dict, rec_dict

    def __call__(self, config, **kwargs):
        """Apply keyword overrides to ``config`` and return it.

        The sections of ``config`` are updated in place. Afterwards the
        thread-count settings of the "Global" section are copied into the
        "Det", "Cls" and "Rec" sections.

        Args:
            config: Mapping with "Global", "Det", "Cls" and "Rec" sections.
            **kwargs: Flat overrides, see :meth:`parse_kwargs`.

        Returns:
            The same ``config`` object after the updates.
        """
        global_dict, det_dict, cls_dict, rec_dict = self.parse_kwargs(**kwargs)

        # Every helper mutates the section it is given, so no intermediate
        # dictionary needs to be assembled here.
        self.update_global_params(config["Global"], global_dict)
        self.update_params(
            config["Det"],
            det_dict,
            "det_",
            ["det_model_path", "det_use_cuda", "det_use_dml"],
        )
        self.update_params(
            config["Cls"],
            cls_dict,
            "cls_",
            ["cls_label_list", "cls_model_path", "cls_use_cuda", "cls_use_dml"],
        )
        self.update_params(
            config["Rec"],
            rec_dict,
            "rec_",
            ["rec_model_path", "rec_use_cuda", "rec_use_dml"],
        )

        update_params = ["intra_op_num_threads", "inter_op_num_threads"]
        return self.update_global_to_module(
            config, update_params, src="Global", dsts=["Det", "Cls", "Rec"]
        )

    def update_global_to_module(
        self, config, params: List[str], src: str, dsts: List[str]
    ):
        """Copy the entries named in ``params`` from section ``src`` into each of ``dsts``.

        Raises:
            KeyError: If ``src`` lacks one of ``params`` or a section is missing.

        Returns:
            The same ``config`` object.
        """
        for dst in dsts:
            for param in params:
                config[dst][param] = config[src][param]
        return config

    def update_global_params(self, config, global_dict):
        """Update the global section ``config`` in place with ``global_dict``."""
        if global_dict:
            config.update(global_dict)
        return config

    def update_params(
        self,
        config,
        param_dict: Dict[str, Any],
        prefix: str,
        need_remove_prefix: Optional[List[str]] = None,
    ):
        """Update one module section in place with its overrides.

        Args:
            config: The section to update (e.g. the "Det" section).
            param_dict: Overrides for this section; empty means no change.
            prefix: Key prefix of this section, e.g. ``"rec_"``.
            need_remove_prefix: Keys whose ``prefix`` must be stripped before
                they are written into ``config``.

        Returns:
            The same ``config`` section.
        """
        if not param_dict:
            return config

        filter_dict = self.remove_prefix(param_dict, prefix, need_remove_prefix)
        # Without an explicit (truthy) model path, fall back to the path
        # already in the section, joined onto ``root_dir``.
        if not filter_dict.get("model_path", None):
            filter_dict["model_path"] = str(root_dir / config["model_path"])

        config.update(filter_dict)
        return config

    @staticmethod
    def remove_prefix(
        config: Dict[str, Any],
        prefix: str,
        need_remove_prefix: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Return ``config`` with ``prefix`` stripped from the selected keys.

        Args:
            config: Mapping of override names to values.
            prefix: Leading text to strip, e.g. ``"cls_"``.
            need_remove_prefix: Keys to rename. When empty or ``None`` the
                input mapping itself is returned unchanged.

        Returns:
            A new dict with the renamed keys, or ``config`` itself when there
            is nothing to rename.
        """
        if not need_remove_prefix:
            return config

        new_rec_dict = {}
        for k, v in config.items():
            # Slice off only a leading prefix so later occurrences of the
            # prefix text inside the key are preserved.
            if k in need_remove_prefix and k.startswith(prefix):
                k = k[len(prefix):]
            new_rec_dict[k] = v
        return new_rec_dict