  1. """
  2. Copyright (c) Alibaba, Inc. and its affiliates.
  3. """
  4. import math
  5. import os
  6. import time
  7. import traceback
  8. import cv2
  9. import numpy as np
  10. import torch
  11. import torch.nn.functional as F
  12. from easydict import EasyDict as edict
  13. from sorawm.iopaint.model.anytext.ocr_recog.RecModel import RecModel


def min_bounding_rect(img):
    ret, thresh = cv2.threshold(img, 127, 255, 0)
    contours, hierarchy = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        print("Bad contours, using fake bbox...")
        return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
    max_contour = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(max_contour)
    box = cv2.boxPoints(rect)
    box = box.astype(np.intp)  # np.int0 was removed in NumPy 2.0; np.intp is its alias
    # Sort the four corners into (top-left, top-right, bottom-right, bottom-left).
    x_sorted = sorted(box, key=lambda x: x[0])
    left = x_sorted[:2]
    right = x_sorted[2:]
    left = sorted(left, key=lambda x: x[1])
    (tl, bl) = left
    right = sorted(right, key=lambda x: x[1])
    (tr, br) = right
    if tl[1] > bl[1]:
        (tl, bl) = (bl, tl)
    if tr[1] > br[1]:
        (tr, br) = (br, tr)
    return np.array([tl, tr, br, bl])
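

# A minimal sketch (not part of the original pipeline): draw a filled rotated
# quadrilateral on a blank mask and recover its ordered corners with
# min_bounding_rect. All coordinates here are arbitrary.
def _example_min_bounding_rect():
    mask = np.zeros((200, 200), dtype=np.uint8)
    pts = np.array([[50, 80], [150, 60], [160, 110], [60, 130]], dtype=np.int32)
    cv2.fillPoly(mask, [pts], 255)
    return min_bounding_rect(mask)  # 4x2 array ordered (tl, tr, br, bl)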


def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
    model_file_path = model_dir
    if model_file_path is not None and not os.path.exists(model_file_path):
        raise ValueError("model file path not found: {}".format(model_file_path))

    if is_onnx:
        import onnxruntime as ort

        sess = ort.InferenceSession(
            model_file_path, providers=["CPUExecutionProvider"]
        )  # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
        return sess
    else:
        if model_lang == "ch":
            n_class = 6625
        elif model_lang == "en":
            n_class = 97
        else:
            raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
        rec_config = edict(
            in_channels=3,
            backbone=edict(
                type="MobileNetV1Enhance",
                scale=0.5,
                last_conv_stride=[1, 2],
                last_pool_type="avg",
            ),
            neck=edict(
                type="SequenceEncoder",
                encoder_type="svtr",
                dims=64,
                depth=2,
                hidden_dims=120,
                use_guide=True,
            ),
            head=edict(
                type="CTCHead",
                fc_decay=0.00001,
                out_channels=n_class,
                return_feats=True,
            ),
        )

        rec_model = RecModel(rec_config)
        if model_file_path is not None:
            rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
        return rec_model.eval()
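

# A minimal sketch (not part of the original pipeline): with model_dir=None,
# create_predictor builds an untrained RecModel with random weights, which is
# enough to check input/output shapes; real inference needs a trained
# checkpoint such as the one referenced in main() below.
def _example_create_predictor():
    model = create_predictor(model_dir=None, model_lang="en")
    dummy = torch.zeros(1, 3, 48, 320)  # NCHW, the shape this file uses elsewhere
    with torch.no_grad():
        out = model(dummy)
    return out["ctc"].shape  # per pred_imglist below, "ctc" holds the logits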


def _check_image_file(path):
    img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
    return any(path.lower().endswith(e) for e in img_end)


def get_image_file_list(img_file):
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("no image files found in {}".format(img_file))
    if os.path.isfile(img_file) and _check_image_file(img_file):
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if os.path.isfile(file_path) and _check_image_file(file_path):
                imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("no image files found in {}".format(img_file))
    imgs_lists = sorted(imgs_lists)
    return imgs_lists
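

# A minimal sketch (not part of the original pipeline): write one dummy image
# into a temporary directory and list it back with get_image_file_list. The
# file name is a throwaway placeholder.
def _example_get_image_file_list():
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    cv2.imwrite(
        os.path.join(tmp_dir, "sample.png"), np.zeros((8, 8, 3), dtype=np.uint8)
    )
    return get_image_file_list(tmp_dir)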


class TextRecognizer(object):
    def __init__(self, args, predictor):
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.predictor = predictor
        self.chars = self.get_char_dict(args.rec_char_dict_path)
        self.char2id = {x: i for i, x in enumerate(self.chars)}
        self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
        self.use_fp16 = args.use_fp16

    # img: CHW
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[0]
        imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[1:]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = F.interpolate(
            img.unsqueeze(0),
            size=(imgH, resized_w),
            mode="bilinear",
            align_corners=True,
        )
        resized_image /= 255.0
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
        padding_im[:, :, 0:resized_w] = resized_image[0]
        return padding_im
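
    # Note on resize_norm_img above: pixels arrive in 0-255, so /255, -0.5,
    # /0.5 maps them to [-1, 1] (0 -> -1.0, 255 -> 1.0); the right-side
    # padding stays at 0.0, i.e. mid-gray on that scale.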

    # img_list: list of tensors with shape chw 0-255
    def pred_imglist(self, img_list, show_debug=False, is_ori=False):
        img_num = len(img_list)
        assert img_num > 0
        # Calculate the aspect ratio of all text bars.
        width_list = []
        for img in img_list:
            width_list.append(img.shape[2] / float(img.shape[1]))
        # Sorting can speed up the recognition process.
        indices = torch.from_numpy(np.argsort(np.array(width_list)))
        batch_num = self.rec_batch_num
        preds_all = [None] * img_num
        preds_neck_all = [None] * img_num
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[1:]
                if h > w * 1.2:
                    # Rotate near-vertical crops 90 degrees before recognition.
                    img = img_list[indices[ino]]
                    img = torch.transpose(img, 1, 2).flip(dims=[1])
                    img_list[indices[ino]] = img
                    h, w = img.shape[1:]
                # wh_ratio = w * 1.0 / h
                # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # commented out to keep a fixed ratio
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
                if self.use_fp16:
                    norm_img = norm_img.half()
                norm_img = norm_img.unsqueeze(0)
                norm_img_batch.append(norm_img)
            norm_img_batch = torch.cat(norm_img_batch, dim=0)
            if show_debug:
                for i in range(len(norm_img_batch)):
                    _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
                    # Invert the [-1, 1] normalization: x = (norm * 0.5 + 0.5) * 255.
                    _img = (_img * 0.5 + 0.5) * 255
                    _img = _img[:, :, ::-1]  # flip channel order
                    file_name = f"{indices[beg_img_no + i]}"
                    file_name = file_name + "_ori" if is_ori else file_name
                    cv2.imwrite(file_name + ".jpg", _img.clip(0, 255).astype(np.uint8))
            if self.is_onnx:
                input_dict = {}
                input_dict[self.predictor.get_inputs()[0].name] = (
                    norm_img_batch.detach().cpu().numpy()
                )
                outputs = self.predictor.run(None, input_dict)
                preds = {}
                preds["ctc"] = torch.from_numpy(outputs[0])
                # The ONNX session does not return neck features; use placeholders.
                preds["ctc_neck"] = [torch.zeros(1)] * img_num
            else:
                preds = self.predictor(norm_img_batch)
            for rno in range(preds["ctc"].shape[0]):
                preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
                preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
        return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)

    def get_char_dict(self, character_dict_path):
        character_str = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode("utf-8").strip("\n").strip("\r\n")
                character_str.append(line)
        dict_character = list(character_str)
        dict_character = ["sos"] + dict_character + [" "]  # eos is space
        return dict_character

    def get_text(self, order):
        char_list = [self.chars[text_id] for text_id in order]
        return "".join(char_list)

    def decode(self, mat):
        text_index = mat.detach().cpu().numpy().argmax(axis=1)
        ignored_tokens = [0]
        selection = np.ones(len(text_index), dtype=bool)
        selection[1:] = text_index[1:] != text_index[:-1]
        for ignored_token in ignored_tokens:
            selection &= text_index != ignored_token
        return text_index[selection], np.where(selection)[0]
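
    # decode() above is greedy CTC decoding: per-step argmax, collapse
    # repeats, drop the blank (index 0). Worked example: argmax indices
    # [0, 5, 5, 0, 7] -> keep positions [1, 4] -> decoded ids [5, 7].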

    def get_ctcloss(self, preds, gt_text, weight):
        if not isinstance(weight, torch.Tensor):
            weight = torch.tensor(weight).to(preds.device)
        ctc_loss = torch.nn.CTCLoss(reduction="none")
        log_probs = preds.log_softmax(dim=2).permute(1, 0, 2)  # NTC -> TNC
        targets = []
        target_lengths = []
        for t in gt_text:
            # Unknown characters fall back to the last index (the space/eos slot).
            targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
            target_lengths += [len(t)]
        targets = torch.tensor(targets).to(preds.device)
        target_lengths = torch.tensor(target_lengths).to(preds.device)
        input_lengths = torch.tensor([log_probs.shape[0]] * log_probs.shape[1]).to(
            preds.device
        )
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss = loss / input_lengths * weight
        return loss
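

# A minimal, self-contained sketch of the torch.nn.CTCLoss call pattern used
# in get_ctcloss above; the shapes and target values below are made up for
# illustration and are not part of the original pipeline.
def _example_ctc_loss():
    T, N, C = 10, 2, 8  # time steps, batch size, classes (index 0 = blank)
    log_probs = torch.randn(T, N, C).log_softmax(dim=2)
    targets = torch.tensor([1, 2, 3, 4, 5])  # both targets, concatenated
    target_lengths = torch.tensor([3, 2])  # first target [1, 2, 3], second [4, 5]
    input_lengths = torch.tensor([T, T])
    loss_fn = torch.nn.CTCLoss(reduction="none")
    return loss_fn(log_probs, targets, input_lengths, target_lengths)  # shape (N,)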


def main():
    rec_model_dir = "./ocr_weights/ppv3_rec.pth"
    predictor = create_predictor(rec_model_dir)
    args = edict()
    args.rec_image_shape = "3, 48, 320"
    args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
    args.rec_batch_num = 6
    args.use_fp16 = False  # TextRecognizer.__init__ reads this attribute
    text_recognizer = TextRecognizer(args, predictor)
    image_dir = "./test_imgs_cn"
    gt_text = ["韩国小馆"] * 14

    image_file_list = get_image_file_list(image_dir)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img = cv2.imread(image_file)
        if img is None:
            print("error in loading image: {}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
    try:
        tic = time.time()
        times = []
        for i in range(10):
            preds, _ = text_recognizer.pred_imglist(img_list)  # get text
            preds_all = preds.softmax(dim=2)
            times += [(time.time() - tic) * 1000.0]
            tic = time.time()
        print(times)
        print(np.mean(times[1:]) / len(preds_all))
        weight = np.ones(len(gt_text))
        loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
        for i in range(len(valid_image_file_list)):
            pred = preds_all[i]
            order, idx = text_recognizer.decode(pred)
            text = text_recognizer.get_text(order)
            print(
                f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
            )
    except Exception as E:
        print(traceback.format_exc(), E)


if __name__ == "__main__":
    main()