import os
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from sorawm.iopaint.helper import download_model, get_cache_path_by_url, load_jit_model
from sorawm.iopaint.schema import InpaintRequest

from .base import InpaintModel

ZITS_INPAINT_MODEL_URL = os.environ.get(
    "ZITS_INPAINT_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
)
ZITS_INPAINT_MODEL_MD5 = os.environ.get(
    "ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154"
)

ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
    "ZITS_EDGE_LINE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
)
ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get(
    "ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f"
)

ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
    "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
)
ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get(
    "ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927"
)

ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
    "ZITS_WIRE_FRAME_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
)
ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get(
    "ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b"
)
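
# The download URLs and MD5 checksums above can be overridden through
# environment variables set before this module is imported, e.g. to point at
# a local mirror (sketch; the mirror host below is hypothetical):
#
#   export ZITS_INPAINT_MODEL_URL="https://mirror.example.com/zits-inpaint-0717.pt"
#   export ZITS_INPAINT_MODEL_MD5="9978cc7157dc29699e42308d675b2154"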


def resize(img, height, width, center_crop=False):
    imgh, imgw = img.shape[0:2]

    if center_crop and imgh != imgw:
        # center crop
        side = np.minimum(imgh, imgw)
        j = (imgh - side) // 2
        i = (imgw - side) // 2
        img = img[j : j + side, i : i + side, ...]

    if imgh > height and imgw > width:
        inter = cv2.INTER_AREA
    else:
        inter = cv2.INTER_LINEAR
    # NOTE: cv2.resize takes (width, height); every call site in this file
    # passes equal height/width, so the swapped order is harmless here.
    img = cv2.resize(img, (height, width), interpolation=inter)

    return img


def to_tensor(img, scale=True, norm=False):
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    c = img.shape[-1]

    if scale:
        img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255)
    else:
        img_t = torch.from_numpy(img).permute(2, 0, 1).float()

    if norm:
        # norm=True assumes a 3-channel image: mean/std have three entries
        mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1)
        img_t = (img_t - mean) / std
    return img_t
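

# Minimal sketch (not part of the original module) of the value ranges
# to_tensor produces; `_to_tensor_demo` is a hypothetical helper.
def _to_tensor_demo():
    demo_img = np.full((4, 4, 3), 255, dtype=np.uint8)
    t_scaled = to_tensor(demo_img)             # [3, 4, 4], values in [0, 1]
    t_normed = to_tensor(demo_img, norm=True)  # [3, 4, 4], values in [-1, 1]
    assert t_scaled.max().item() == 1.0
    assert t_normed.max().item() == 1.0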


def load_masked_position_encoding(mask):
    ones_filter = np.ones((3, 3), dtype=np.float32)
    d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32)
    d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32)
    d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32)
    d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32)
    str_size = 256
    pos_num = 128

    ori_mask = mask.copy()
    ori_h, ori_w = ori_mask.shape[0:2]
    ori_mask = ori_mask / 255
    mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA)
    mask[mask > 0] = 255
    h, w = mask.shape[0:2]
    mask3 = mask.copy()
    mask3 = 1.0 - (mask3 / 255.0)  # 1 = known pixel, 0 = masked pixel
    pos = np.zeros((h, w), dtype=np.int32)
    direct = np.zeros((h, w, 4), dtype=np.int32)
    i = 0
    # Grow the known region one ring per iteration: pos stores the ring index
    # (distance in dilation steps to the nearest known pixel); direct records
    # which of four diagonal 2x2 neighborhoods first reached each pixel.
    while np.sum(1 - mask3) > 0:
        i += 1
        mask3_ = cv2.filter2D(mask3, -1, ones_filter)
        mask3_[mask3_ > 0] = 1
        sub_mask = mask3_ - mask3
        pos[sub_mask == 1] = i

        m = cv2.filter2D(mask3, -1, d_filter1)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 0] = 1

        m = cv2.filter2D(mask3, -1, d_filter2)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 1] = 1

        m = cv2.filter2D(mask3, -1, d_filter3)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 2] = 1

        m = cv2.filter2D(mask3, -1, d_filter4)
        m[m > 0] = 1
        m = m - mask3
        direct[m == 1, 3] = 1

        mask3 = mask3_

    abs_pos = pos.copy()
    rel_pos = pos / (str_size / 2)  # roughly 0~1, can exceed 1 for large masks
    rel_pos = (rel_pos * pos_num).astype(np.int32)
    rel_pos = np.clip(rel_pos, 0, pos_num - 1)

    if ori_w != w or ori_h != h:
        rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        rel_pos[ori_mask == 0] = 0
        direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST)
        direct[ori_mask == 0, :] = 0

    return rel_pos, abs_pos, direct
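

# Illustrative sketch (not part of the original module): run the position
# encoding on a toy 8x8 mask with a masked square in the middle and inspect
# the output shapes; `_position_encoding_demo` is a hypothetical helper.
def _position_encoding_demo():
    toy_mask = np.zeros((8, 8), dtype=np.uint8)
    toy_mask[2:6, 2:6] = 255  # 255 = masked region
    rel_pos, abs_pos, direct = load_masked_position_encoding(toy_mask)
    # rel_pos/direct are resized back to the input size; abs_pos stays at 256
    print(rel_pos.shape, abs_pos.shape, direct.shape)  # (8, 8) (256, 256) (8, 8, 4)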


def load_image(img, mask, device, sigma256=3.0):
    """
    Args:
        img: [H, W, C] RGB
        mask: [H, W], 255 marks the masked region
        sigma256: Gaussian smoothing sigma used for edge detection at 256px
    Returns:
        dict of model inputs, all moved to `device`
    """
    h, w, _ = img.shape
    imgh, imgw = img.shape[0:2]
    img_256 = resize(img, 256, 256)

    mask = (mask > 127).astype(np.uint8) * 255
    mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)
    mask_256[mask_256 > 0] = 255

    mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA)
    mask_512[mask_512 > 0] = 255

    # original skimage implementation
    # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny
    # low_threshold: Lower bound for hysteresis thresholding (linking edges). If None, low_threshold is set to 10% of dtype's max.
    # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype's max.
    try:
        import skimage

        gray_256 = skimage.color.rgb2gray(img_256)
        edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float)
        # cv2.imwrite("skimage_gray.jpg", (gray_256*255).astype(np.uint8))
        # cv2.imwrite("skimage_edge.jpg", (edge_256*255).astype(np.uint8))
    except Exception:
        # fall back to OpenCV Canny when scikit-image is unavailable
        gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY)
        gray_256_blured = cv2.GaussianBlur(
            gray_256, ksize=(7, 7), sigmaX=sigma256, sigmaY=sigma256
        )
        edge_256 = cv2.Canny(
            gray_256_blured, threshold1=int(255 * 0.1), threshold2=int(255 * 0.2)
        )
        # scale to {0, 1} to match the value range of the skimage branch
        edge_256 = (edge_256 > 0).astype(float)
        # cv2.imwrite("opencv_edge.jpg", edge_256)

    # line
    img_512 = resize(img, 512, 512)

    rel_pos, abs_pos, direct = load_masked_position_encoding(mask)

    batch = dict()
    batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device)
    batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device)
    batch["masks"] = to_tensor(mask).unsqueeze(0).to(device)
    batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device)
    batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device)
    batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device)
    batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device)
    batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device)
    batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device)
    batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device)
    batch["h"] = imgh
    batch["w"] = imgw

    return batch


def to_device(data, device):
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        for key in data:
            if isinstance(data[key], torch.Tensor):
                data[key] = data[key].to(device)
        return data
    if isinstance(data, list):
        return [to_device(d, device) for d in data]
    # pass non-tensor types through unchanged
    return data


class ZITS(InpaintModel):
    name = "zits"
    min_size = 256
    pad_mod = 32
    pad_to_square = True
    is_erase_model = True

    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch device to run the model on
        """
        super().__init__(device)
        self.device = device
        self.sample_edge_line_iterations = 1

    def init_model(self, device, **kwargs):
        self.wireframe = load_jit_model(
            ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5
        )
        self.edge_line = load_jit_model(
            ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5
        )
        self.structure_upsample = load_jit_model(
            ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
        )
        self.inpaint = load_jit_model(
            ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5
        )

    @staticmethod
    def download():
        download_model(ZITS_WIRE_FRAME_MODEL_URL, ZITS_WIRE_FRAME_MODEL_MD5)
        download_model(ZITS_EDGE_LINE_MODEL_URL, ZITS_EDGE_LINE_MODEL_MD5)
        download_model(
            ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
        )
        download_model(ZITS_INPAINT_MODEL_URL, ZITS_INPAINT_MODEL_MD5)

    @staticmethod
    def is_downloaded() -> bool:
        model_paths = [
            get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL),
            get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL),
            get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL),
            get_cache_path_by_url(ZITS_INPAINT_MODEL_URL),
        ]
        return all([os.path.exists(it) for it in model_paths])

    def wireframe_edge_and_line(self, items, enable: bool):
        # adds "edge" and "line" keys to items
        if not enable:
            items["edge"] = torch.zeros_like(items["masks"])
            items["line"] = torch.zeros_like(items["masks"])
            return

        start = time.time()
        try:
            line_256 = self.wireframe_forward(
                items["img_512"],
                h=256,
                w=256,
                masks=items["mask_512"],
                mask_th=0.85,
            )
        except Exception:
            line_256 = torch.zeros_like(items["mask_256"])

        print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms")

        # np_line = (line[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("line.jpg", np_line)

        start = time.time()
        edge_pred, line_pred = self.sample_edge_line_logits(
            context=[items["img_256"], items["edge_256"], line_256],
            mask=items["mask_256"].clone(),
            iterations=self.sample_edge_line_iterations,
            add_v=0.05,
            mul_v=4,
        )
        print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms")

        # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("edge_pred.jpg", np_edge_pred)
        # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
        # cv2.imwrite("line_pred.jpg", np_line_pred)
        # exit()

        input_size = min(items["h"], items["w"])
        if input_size > 256:
            # upsample the 256px structure maps until they cover the input
            # size, then resize exactly to (input_size, input_size)
            while edge_pred.shape[2] < input_size:
                edge_pred = self.structure_upsample(edge_pred)
                edge_pred = torch.sigmoid((edge_pred + 2) * 2)

                line_pred = self.structure_upsample(line_pred)
                line_pred = torch.sigmoid((line_pred + 2) * 2)

            edge_pred = F.interpolate(
                edge_pred,
                size=(input_size, input_size),
                mode="bilinear",
                align_corners=False,
            )
            line_pred = F.interpolate(
                line_pred,
                size=(input_size, input_size),
                mode="bilinear",
                align_corners=False,
            )

            # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8)
            # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred)
            # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8)
            # cv2.imwrite("line_pred_upsample.jpg", np_line_pred)
            # exit()

        items["edge"] = edge_pred.detach()
        items["line"] = line_pred.detach()

    @torch.no_grad()
    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have the same size.
        image: [H, W, C] RGB
        mask: [H, W]
        return: BGR IMAGE
        """
        mask = mask[:, :, 0]

        items = load_image(image, mask, device=self.device)

        self.wireframe_edge_and_line(items, config.zits_wireframe)

        inpainted_image = self.inpaint(
            items["images"],
            items["masks"],
            items["edge"],
            items["line"],
            items["rel_pos"],
            items["direct"],
        )

        inpainted_image = inpainted_image * 255.0
        inpainted_image = (
            inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8)
        )
        inpainted_image = inpainted_image[:, :, ::-1]  # RGB -> BGR

        # cv2.imwrite("inpainted.jpg", inpainted_image)
        # exit()

        return inpainted_image

    def wireframe_forward(self, images, h, w, masks, mask_th=0.925):
        lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1)
        lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1)
        # keep the normalization stats on the input device; on CUDA a CPU/GPU
        # mismatch here would raise and silently disable wireframes via the
        # caller's try/except
        lcnn_mean = lcnn_mean.to(images.device)
        lcnn_std = lcnn_std.to(images.device)
        images = images * 255.0
        # lcnn expects masked pixels to be filled with 127.5
        masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5
        masked_images = (masked_images - lcnn_mean) / lcnn_std

        def to_int(x):
            return tuple(map(int, x))

        lines_tensor = []
        lmap = np.zeros((h, w))

        output_masked = self.wireframe(masked_images)
        output_masked = to_device(output_masked, "cpu")
        if output_masked["num_proposals"] == 0:
            lines_masked = []
            scores_masked = []
        else:
            lines_masked = output_masked["lines_pred"].numpy()
            lines_masked = [
                [line[1] * h, line[0] * w, line[3] * h, line[2] * w]
                for line in lines_masked
            ]
            scores_masked = output_masked["lines_score"].numpy()

        for line, score in zip(lines_masked, scores_masked):
            if score > mask_th:
                try:
                    import skimage

                    rr, cc, value = skimage.draw.line_aa(
                        *to_int(line[0:2]), *to_int(line[2:4])
                    )
                    lmap[rr, cc] = np.maximum(lmap[rr, cc], value)
                except Exception:
                    # cv2.line takes (x, y) points, hence the [::-1] swap
                    cv2.line(
                        lmap,
                        to_int(line[0:2][::-1]),
                        to_int(line[2:4][::-1]),
                        (1, 1, 1),
                        1,
                        cv2.LINE_AA,
                    )

        lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8)
        lines_tensor.append(to_tensor(lmap).unsqueeze(0))

        lines_tensor = torch.cat(lines_tensor, dim=0)
        return lines_tensor.detach().to(self.device)
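
    # Worked example of the coordinate handling above (hypothetical values):
    # "lines_pred" appears to be normalized (x1, y1, x2, y2) in [0, 1], and the
    # list comprehension converts it to pixel (row1, col1, row2, col2). With
    # h = w = 256 and a proposal (0.25, 0.10, 0.75, 0.90):
    #   row1, col1 = 0.10 * 256, 0.25 * 256  ->  25.6,  64.0
    #   row2, col2 = 0.90 * 256, 0.75 * 256  -> 230.4, 192.0
    # skimage.draw.line_aa then takes (row, col) endpoints, while the cv2.line
    # fallback takes (x, y) points, hence the [::-1] swap.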

    def sample_edge_line_logits(
        self, context, mask=None, iterations=1, add_v=0, mul_v=4
    ):
        [img, edge, line] = context
        img = img * (1 - mask)
        edge = edge * (1 - mask)
        line = line * (1 - mask)

        # Iteratively commit the most confident predicted edge/line pixels:
        # each pass fixes a growing fraction of the masked positions and
        # removes them from the mask before the next pass.
        for i in range(iterations):
            edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask)
            edge_pred = torch.sigmoid(edge_logits)
            line_pred = torch.sigmoid((line_logits + add_v) * mul_v)
            edge = edge + edge_pred * mask
            edge[edge >= 0.25] = 1
            edge[edge < 0.25] = 0
            line = line + line_pred * mask

            b, _, h, w = edge_pred.shape
            edge_pred = edge_pred.reshape(b, -1, 1)
            line_pred = line_pred.reshape(b, -1, 1)
            mask = mask.reshape(b, -1)

            edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1)
            line_probs = torch.cat([1 - line_pred, line_pred], dim=-1)
            edge_probs[:, :, 1] += 0.5
            line_probs[:, :, 1] += 0.5
            edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100)
            line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100)

            indices = torch.sort(
                edge_max_probs + line_max_probs, dim=-1, descending=True
            )[1]

            for ii in range(b):
                keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))

                assert (
                    torch.sum(mask[ii][indices[ii, :keep]]) == keep
                ), "the top-k indices must all lie inside the mask"
                mask[ii][indices[ii, :keep]] = 0

            mask = mask.reshape(b, 1, h, w)
            edge = edge * (1 - mask)
            line = line * (1 - mask)

        edge, line = edge.to(torch.float32), line.to(torch.float32)
        return edge, line
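

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions: the
# weights can be downloaded, InpaintRequest is constructible with defaults
# plus the `zits_wireframe` flag used in forward(), and the input is square
# with a side that is a multiple of `pad_mod`, so forward() can be called
# directly without the base class's padding path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ZITS(device)

    # toy 512x512 RGB image with a masked square in the middle
    image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[192:320, 192:320] = 255

    config = InpaintRequest(zits_wireframe=True)
    result = model.forward(image, mask[:, :, None], config)  # returns BGR
    cv2.imwrite("zits_demo_result.jpg", result)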