retinaface.py
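"""RetinaFace face detection and five-point landmark alignment.

Defines a `RetinaFace` nn.Module built from a MobileNetV1-0.25 or ResNet-50
backbone with an FPN and SSH context modules, plus single-image
(`detect_faces`, `align_multi`) and batched (`batched_detect_faces`)
inference helpers.
"""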

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter

from .align_trans import get_reference_facial_points, warp_and_crop_face
from .retinaface_net import (
    FPN,
    SSH,
    MobileNetV1,
    make_bbox_head,
    make_class_head,
    make_landmark_head,
)
from .retinaface_utils import (
    PriorBox,
    batched_decode,
    batched_decode_landm,
    decode,
    decode_landm,
    py_cpu_nms,
)


def generate_config(network_name):
    cfg_mnet = {
        "name": "mobilenet0.25",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 32,
        "ngpu": 1,
        "epoch": 250,
        "decay1": 190,
        "decay2": 220,
        "image_size": 640,
        "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
        "in_channel": 32,
        "out_channel": 64,
    }

    cfg_re50 = {
        "name": "Resnet50",
        "min_sizes": [[16, 32], [64, 128], [256, 512]],
        "steps": [8, 16, 32],
        "variance": [0.1, 0.2],
        "clip": False,
        "loc_weight": 2.0,
        "gpu_train": True,
        "batch_size": 24,
        "ngpu": 4,
        "epoch": 100,
        "decay1": 70,
        "decay2": 90,
        "image_size": 840,
        "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
        "in_channel": 256,
        "out_channel": 256,
    }

    if network_name == "mobile0.25":
        return cfg_mnet
    elif network_name == "resnet50":
        return cfg_re50
    else:
        raise NotImplementedError(f"network_name={network_name}")


class RetinaFace(nn.Module):
    def __init__(self, network_name="resnet50", half=False, phase="test", device=None):
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if device is None
            else device
        )
        super(RetinaFace, self).__init__()
        self.half_inference = half
        cfg = generate_config(network_name)
        self.backbone = cfg["name"]
        self.model_name = f"retinaface_{network_name}"
        self.cfg = cfg
        self.phase = phase
        self.target_size, self.max_size = 1600, 2150
        self.resize, self.scale, self.scale1 = 1.0, None, None
        # Per-channel BGR means, subtracted from inputs before inference.
        self.mean_tensor = torch.tensor(
            [[[[104.0]], [[117.0]], [[123.0]]]], device=self.device
        )
        self.reference = get_reference_facial_points(default_square=True)

        # Build network.
        backbone = None
        if cfg["name"] == "mobilenet0.25":
            backbone = MobileNetV1()
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
        elif cfg["name"] == "Resnet50":
            import torchvision.models as models

            backbone = models.resnet50(pretrained=False)
            self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])

        in_channels_stage2 = cfg["in_channel"]
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg["out_channel"]
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
        self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])

        self.to(self.device)
        self.eval()
        if self.half_inference:
            self.half()
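
    # Note: no detector weights are loaded here (the ResNet-50 backbone is
    # built with pretrained=False); callers are expected to load a trained
    # state_dict themselves before inference. See the usage sketch at the
    # bottom of this file.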

    def forward(self, inputs):
        out = self.body(inputs)
        if self.backbone in ("mobilenet0.25", "Resnet50"):
            out = list(out.values())

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        # Per-level head outputs, concatenated over the prior dimension.
        bbox_regressions = torch.cat(
            [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        classifications = torch.cat(
            [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
        )
        ldm_regressions = torch.cat(
            [self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1
        )

        if self.phase == "train":
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (
                bbox_regressions,
                F.softmax(classifications, dim=-1),
                ldm_regressions,
            )
        return output

    def __detect_faces(self, inputs):
        # get scale
        height, width = inputs.shape[2:]
        self.scale = torch.tensor(
            [width, height, width, height], dtype=torch.float32, device=self.device
        )
        # One (width, height) pair per landmark coordinate (5 landmarks).
        self.scale1 = torch.tensor(
            [width, height] * 5, dtype=torch.float32, device=self.device
        )

        # forward
        inputs = inputs.to(self.device)
        if self.half_inference:
            inputs = inputs.half()
        loc, conf, landmarks = self(inputs)

        # get priorbox
        priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
        priors = priorbox.forward().to(self.device)

        return loc, conf, landmarks, priors

    # single image detection
    def transform(self, image, use_origin_size):
        # convert to opencv format
        if isinstance(image, Image.Image):
            image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        image = image.astype(np.float32)

        # testing scale
        im_size_min = np.min(image.shape[0:2])
        im_size_max = np.max(image.shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            image = cv2.resize(
                image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
            )

        # convert to torch.tensor format (1, c, h, w); mean subtraction
        # happens later, in detect_faces, via self.mean_tensor
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).unsqueeze(0)

        return image, resize

    def detect_faces(
        self,
        image,
        conf_threshold=0.8,
        nms_threshold=0.4,
        use_origin_size=True,
    ):
        image, self.resize = self.transform(image, use_origin_size)
        image = image.to(self.device)
        if self.half_inference:
            image = image.half()
        image = image - self.mean_tensor

        loc, conf, landmarks, priors = self.__detect_faces(image)

        boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
        boxes = boxes * self.scale / self.resize
        boxes = boxes.cpu().numpy()

        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

        landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
        landmarks = landmarks * self.scale1 / self.resize
        landmarks = landmarks.cpu().numpy()

        # ignore low scores
        inds = np.where(scores > conf_threshold)[0]
        boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

        # sort by score, descending
        order = scores.argsort()[::-1]
        boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

        # do NMS
        bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False
        )
        keep = py_cpu_nms(bounding_boxes, nms_threshold)
        bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
        # shape (n_faces, 15): [x1, y1, x2, y2, score, x0, y0, ..., x4, y4]
        return np.concatenate((bounding_boxes, landmarks), axis=1)

    def __align_multi(self, image, boxes, landmarks, limit=None):
        if len(boxes) < 1:
            return [], []

        if limit:
            boxes = boxes[:limit]
            landmarks = landmarks[:limit]

        faces = []
        for landmark in landmarks:
            facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
            warped_face = warp_and_crop_face(
                np.array(image), facial5points, self.reference, crop_size=(112, 112)
            )
            faces.append(warped_face)

        return np.concatenate((boxes, landmarks), axis=1), faces

    def align_multi(self, img, conf_threshold=0.8, limit=None):
        rlt = self.detect_faces(img, conf_threshold=conf_threshold)
        boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
        return self.__align_multi(img, boxes, landmarks, limit)

    # batched detection
    def batched_transform(self, frames, use_origin_size):
        """
        Arguments:
            frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
                type=np.float32, BGR format).
            use_origin_size: whether to use origin size.
        """
        from_PIL = isinstance(frames[0], Image.Image)

        # convert to opencv format
        if from_PIL:
            frames = [
                cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames
            ]
            frames = np.asarray(frames, dtype=np.float32)

        # testing scale
        im_size_min = np.min(frames[0].shape[0:2])
        im_size_max = np.max(frames[0].shape[0:2])
        resize = float(self.target_size) / float(im_size_min)

        # prevent bigger axis from being more than max_size
        if np.round(resize * im_size_max) > self.max_size:
            resize = float(self.max_size) / float(im_size_max)
        resize = 1 if use_origin_size else resize

        # resize
        if resize != 1:
            if not from_PIL:
                # F.interpolate scales the last two dims, so move the
                # (n, h, w, c) tensor to (n, c, h, w) before resizing and back
                # afterwards.
                frames = frames.permute(0, 3, 1, 2)
                frames = F.interpolate(frames, scale_factor=resize)
                frames = frames.permute(0, 2, 3, 1)
            else:
                frames = [
                    cv2.resize(
                        frame,
                        None,
                        None,
                        fx=resize,
                        fy=resize,
                        interpolation=cv2.INTER_LINEAR,
                    )
                    for frame in frames
                ]

        # convert to torch.tensor format (n, c, h, w)
        if not from_PIL:
            frames = frames.permute(0, 3, 1, 2).contiguous()
        else:
            frames = frames.transpose((0, 3, 1, 2))
            frames = torch.from_numpy(frames)

        return frames, resize

    def batched_detect_faces(
        self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True
    ):
        """
        Arguments:
            frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
                type=np.uint8, BGR format).
            conf_threshold: confidence threshold.
            nms_threshold: nms threshold.
            use_origin_size: whether to use origin size.
        Returns:
            final_bounding_boxes: list of np.array ([n_boxes, 5],
                type=np.float32).
            final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
        """
        frames, self.resize = self.batched_transform(frames, use_origin_size)
        frames = frames.to(self.device)
        frames = frames - self.mean_tensor

        b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)

        final_bounding_boxes, final_landmarks = [], []

        # decode
        priors = priors.unsqueeze(0)
        b_loc = (
            batched_decode(b_loc, priors, self.cfg["variance"])
            * self.scale
            / self.resize
        )
        b_landmarks = (
            batched_decode_landm(b_landmarks, priors, self.cfg["variance"])
            * self.scale1
            / self.resize
        )
        b_conf = b_conf[:, :, 1]

        # index for selection
        b_indice = b_conf > conf_threshold

        # concat
        b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()

        for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
            # ignore low scores
            pred, landm = pred[inds, :], landm[inds, :]
            if pred.shape[0] == 0:
                final_bounding_boxes.append(np.array([], dtype=np.float32))
                final_landmarks.append(np.array([], dtype=np.float32))
                continue
            # to CPU
            bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()

            # NMS
            keep = py_cpu_nms(bounding_boxes, nms_threshold)
            bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]

            # append
            final_bounding_boxes.append(bounding_boxes)
            final_landmarks.append(landmarks)

        return final_bounding_boxes, final_landmarks
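

# ----------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). This file uses
# relative imports, so import RetinaFace through its parent package instead of
# running it directly. The checkpoint and image paths are placeholder
# assumptions, and loading presumes the checkpoint keys match this module's
# parameter names.
#
#     detector = RetinaFace(network_name="resnet50", half=False)
#     state_dict = torch.load("detection_Resnet50_Final.pth",
#                             map_location=detector.device)
#     detector.load_state_dict(state_dict)
#
#     img = cv2.imread("test.jpg")  # HWC, BGR, uint8
#     with torch.no_grad():
#         # each row: [x1, y1, x2, y2, score] + 5 * (x, y) landmarks
#         dets = detector.detect_faces(img, conf_threshold=0.8)
#         # aligned 112x112 face crops via the five-point reference template
#         _, faces = detector.align_multi(img, conf_threshold=0.8)
#         # batched variant: a list of equally sized PIL.Image frames
#         boxes_list, landms_list = detector.batched_detect_faces(
#             [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))]
#         )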