mi_gan.py

import os

import cv2
import torch

from sorawm.iopaint.helper import (
    boxes_from_mask,
    download_model,
    get_cache_path_by_url,
    load_jit_model,
    norm_img,
    resize_max_size,
)
from sorawm.iopaint.schema import InpaintRequest

from .base import InpaintModel

MIGAN_MODEL_URL = os.environ.get(
    "MIGAN_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
)
MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")


class MIGAN(InpaintModel):
    name = "migan"
    min_size = 512
    pad_mod = 512
    pad_to_square = True
    is_erase_model = True

    def init_model(self, device, **kwargs):
        # load_jit_model fetches the traced checkpoint on first use, verifies
        # its MD5, and loads it via torch.jit onto the target device.
        self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()

    @staticmethod
    def download():
        download_model(MIGAN_MODEL_URL, MIGAN_MODEL_MD5)

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(MIGAN_MODEL_URL))

    @torch.no_grad()
    def __call__(self, image, mask, config: InpaintRequest):
        """
        image: [H, W, C] RGB, not normalized
        mask: [H, W], mask area == 255
        return: BGR image
        """
        # Fast path: a 512x512 input already matches the traced model's size.
        if image.shape[0] == 512 and image.shape[1] == 512:
            return self._pad_forward(image, mask, config)
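
        # For larger images, handle each masked region separately: crop a box
        # around it (with margin), resize the crop to fit within 512, inpaint,
        # then paste the result back at the original resolution.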
        boxes = boxes_from_mask(mask)
        crop_result = []
        config.hd_strategy_crop_margin = 128
        for box in boxes:
            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
            origin_size = crop_image.shape[:2]
            resize_image = resize_max_size(crop_image, size_limit=512)
            resize_mask = resize_max_size(crop_mask, size_limit=512)
            inpaint_result = self._pad_forward(resize_image, resize_mask, config)

            # only paste the masked-area result; unmasked pixels keep their
            # original values (crop_image flipped RGB -> BGR)
            inpaint_result = cv2.resize(
                inpaint_result,
                (origin_size[1], origin_size[0]),
                interpolation=cv2.INTER_CUBIC,
            )
            original_pixel_indices = crop_mask < 127
            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][
                original_pixel_indices
            ]

            crop_result.append((inpaint_result, crop_box))

        # Composite: start from the full image in BGR and paste each crop back.
        inpaint_result = image[:, :, ::-1].copy()
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            inpaint_result[y1:y2, x1:x2, :] = crop_image

        return inpaint_result

    def forward(self, image, mask, config: InpaintRequest):
        """Input and output images have the same size.
        image: [H, W, C] RGB
        mask: [H, W], mask area == 255
        return: BGR image
        """
        image = norm_img(image)  # [0, 1], CHW
        image = image * 2 - 1  # [0, 1] -> [-1, 1]
        mask = (mask > 120) * 255  # binarize
        mask = norm_img(mask)  # [0, 1], CHW

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
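
        # The traced MI-GAN graph takes a 4-channel input: a mask channel
        # encoded as 0.5 - mask (0.5 = keep, -0.5 = hole) stacked with the
        # hole-erased RGB image.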
        erased_img = image * (1 - mask)
        input_image = torch.cat([0.5 - mask, erased_img], dim=1)

        output = self.model(input_image)
        # Convert NCHW in [-1, 1] back to HWC uint8 in [0, 255].
        output = (
            (output.permute(0, 2, 3, 1) * 127.5 + 127.5)
            .round()
            .clamp(0, 255)
            .to(torch.uint8)
        )
        output = output[0].cpu().numpy()
        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return cur_res
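

# Minimal usage sketch, not part of the upstream module. It assumes the
# InpaintModel base constructor takes the target device and calls init_model,
# and that InpaintRequest() is constructible with defaults; the file paths
# below are placeholders.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not MIGAN.is_downloaded():
        MIGAN.download()  # fetch and cache migan_traced.pt

    model = MIGAN(device)

    # RGB image plus a grayscale mask where 255 marks pixels to erase.
    image = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)
    mask = cv2.imread("mask.png", cv2.IMREAD_GRAYSCALE)

    result_bgr = model(image, mask, InpaintRequest())  # __call__ returns BGR
    cv2.imwrite("result.png", result_bgr)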