# base.py — base classes for erase/inpaint models (InpaintModel, DiffusionInpaintModel)
import abc
from typing import Optional

import cv2
import numpy as np
import torch
from loguru import logger

from sorawm.iopaint.helper import (
    boxes_from_mask,
    pad_img_to_modulo,
    resize_max_size,
    switch_mps_device,
)
from sorawm.iopaint.schema import HDStrategy, InpaintRequest, SDSampler

from .helper.g_diffuser_bot import expand_image
from .utils import get_scheduler
class InpaintModel:
    """Abstract base class for erase/inpaint models.

    Subclasses implement `init_model` and `forward`; `__call__` dispatches
    between the high-resolution strategies (CROP / RESIZE / fallthrough to
    the full image) and runs the actual inference through `_pad_forward`.
    """

    # Model identifier (used e.g. by switch_mps_device).
    name = "base"
    # Minimum side length the network accepts; inputs are padded up to it.
    min_size: Optional[int] = None
    # Inputs are padded so height/width are multiples of this value.
    pad_mod = 8
    # Pad input to a square before inference.
    pad_to_square = False
    # Whether this model is a pure object-removal (erase) model.
    is_erase_model = False

    def __init__(self, device, **kwargs):
        """
        Args:
            device: torch device; may be remapped by switch_mps_device for
                models that do not support MPS.
        """
        device = switch_mps_device(self.name, device)
        self.device = device
        self.init_model(device, **kwargs)

    @abc.abstractmethod
    def init_model(self, device, **kwargs):
        """Load weights / build the underlying network on `device`."""
        ...

    @staticmethod
    @abc.abstractmethod
    def is_downloaded() -> bool:
        """Return True when the model weights are already available on disk."""
        return False

    @abc.abstractmethod
    def forward(self, image, mask, config: InpaintRequest):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W, 1] where 255 marks the masked (to-inpaint) region
        return: BGR IMAGE
        """
        ...

    @staticmethod
    def download():
        """Download model weights; default implementation is a no-op."""
        ...

    def _pad_forward(self, image, mask, config: InpaintRequest):
        # Pad to the modulo/min-size the network requires, run forward, then
        # crop the result back to the original size.
        origin_height, origin_width = image.shape[:2]
        pad_image = pad_img_to_modulo(
            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        pad_mask = pad_img_to_modulo(
            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
        )
        # logger.info(f"final forward pad size: {pad_image.shape}")
        # NOTE(review): the pre-processed image/mask are NOT what is passed to
        # forward (pad_image/pad_mask are); they only feed the
        # keep-unmasked-area blending below. Confirm this is intended — e.g. a
        # blurred mask from forward_pre_process never reaches forward().
        image, mask = self.forward_pre_process(image, mask, config)
        result = self.forward(pad_image, pad_mask, config)
        result = result[0:origin_height, 0:origin_width, :]
        result, image, mask = self.forward_post_process(result, image, mask, config)
        if config.sd_keep_unmasked_area:
            # Blend: keep original pixels (RGB flipped to BGR) where mask is 0.
            mask = mask[:, :, np.newaxis]
            result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
        return result

    def forward_pre_process(self, image, mask, config):
        """Hook: transform (image, mask) before inference; identity by default."""
        return image, mask

    def forward_post_process(self, result, image, mask, config):
        """Hook: transform results after inference; identity by default."""
        return result, image, mask

    @torch.no_grad()
    def __call__(self, image, mask, config: InpaintRequest):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        inpaint_result = None
        # logger.info(f"hd_strategy: {config.hd_strategy}")
        if config.hd_strategy == HDStrategy.CROP:
            if max(image.shape) > config.hd_strategy_crop_trigger_size:
                # logger.info("Run crop strategy")
                # Inpaint each masked region independently on a crop, then
                # paste the crops back into a BGR copy of the full image.
                boxes = boxes_from_mask(mask)
                crop_result = []
                for box in boxes:
                    crop_image, crop_box = self._run_box(image, mask, box, config)
                    crop_result.append((crop_image, crop_box))
                inpaint_result = image[:, :, ::-1].copy()
                for crop_image, crop_box in crop_result:
                    x1, y1, x2, y2 = crop_box
                    inpaint_result[y1:y2, x1:x2, :] = crop_image
        elif config.hd_strategy == HDStrategy.RESIZE:
            if max(image.shape) > config.hd_strategy_resize_limit:
                # Downscale, inpaint, upscale back, then restore the original
                # pixels outside the mask so only the masked area changes.
                origin_size = image.shape[:2]
                downsize_image = resize_max_size(
                    image, size_limit=config.hd_strategy_resize_limit
                )
                downsize_mask = resize_max_size(
                    mask, size_limit=config.hd_strategy_resize_limit
                )
                logger.info(
                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
                )
                inpaint_result = self._pad_forward(
                    downsize_image, downsize_mask, config
                )
                # only paste masked area result
                inpaint_result = cv2.resize(
                    inpaint_result,
                    (origin_size[1], origin_size[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
                original_pixel_indices = mask < 127
                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
                    original_pixel_indices
                ]
        if inpaint_result is None:
            # Strategy did not trigger (or is the default): run on the full image.
            inpaint_result = self._pad_forward(image, mask, config)
        return inpaint_result

    def _crop_box(self, image, mask, box, config: InpaintRequest):
        """
        Crop `box` (expanded by the configured margin) out of image/mask,
        clamped to the image bounds.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left, top, right, bottom]
        Returns:
            (crop_img, crop_mask, [l, t, r, b])
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[:2]

        # Desired crop size: box plus margin on every side.
        w = box_w + config.hd_strategy_crop_margin * 2
        h = box_h + config.hd_strategy_crop_margin * 2

        _l = cx - w // 2
        _r = cx + w // 2
        _t = cy - h // 2
        _b = cy + h // 2

        l = max(_l, 0)
        r = min(_r, img_w)
        t = max(_t, 0)
        b = min(_b, img_h)

        # try to get more context when crop around image edge
        if _l < 0:
            r += abs(_l)
        if _r > img_w:
            l -= _r - img_w
        if _t < 0:
            b += abs(_t)
        if _b > img_h:
            t -= _b - img_h

        # Re-clamp in case the shift above pushed past the opposite edge.
        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)

        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        # logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
        return crop_img, crop_mask, [l, t, r, b]

    def _calculate_cdf(self, histogram):
        # Cumulative distribution of a 256-bin histogram, normalized to [0, 1].
        cdf = histogram.cumsum()
        normalized_cdf = cdf / float(cdf.max())
        return normalized_cdf

    def _calculate_lookup(self, source_cdf, reference_cdf):
        # Build the 256-entry LUT for histogram matching: each source
        # intensity maps to the first reference intensity whose CDF value
        # reaches the source CDF value.
        lookup_table = np.zeros(256)
        lookup_val = 0
        for source_index, source_val in enumerate(source_cdf):
            for reference_index, reference_val in enumerate(reference_cdf):
                if reference_val >= source_val:
                    lookup_val = reference_index
                    break
            lookup_table[source_index] = lookup_val
        return lookup_table

    def _match_histograms(self, source, reference, mask):
        # Per-channel histogram matching of `source` to `reference`, using
        # only the unmasked pixels (mask == 0) to estimate both histograms.
        transformed_channels = []
        if len(mask.shape) == 3:
            mask = mask[:, :, -1]
        for channel in range(source.shape[-1]):
            source_channel = source[:, :, channel]
            reference_channel = reference[:, :, channel]
            # only calculate histograms for non-masked parts
            source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
            reference_histogram, _ = np.histogram(
                reference_channel[mask == 0], 256, [0, 256]
            )
            source_cdf = self._calculate_cdf(source_histogram)
            reference_cdf = self._calculate_cdf(reference_histogram)
            lookup = self._calculate_lookup(source_cdf, reference_cdf)
            transformed_channels.append(cv2.LUT(source_channel, lookup))
        result = cv2.merge(transformed_channels)
        result = cv2.convertScaleAbs(result)
        return result

    def _apply_cropper(self, image, mask, config: InpaintRequest):
        # Crop image/mask to the user-specified cropper rectangle
        # (croper_x/y/width/height), clamped to the image bounds.
        img_h, img_w = image.shape[:2]
        l, t, w, h = (
            config.croper_x,
            config.croper_y,
            config.croper_width,
            config.croper_height,
        )
        r = l + w
        b = t + h
        l = max(l, 0)
        r = min(r, img_w)
        t = max(t, 0)
        b = min(b, img_h)
        crop_img = image[t:b, l:r, :]
        crop_mask = mask[t:b, l:r]
        return crop_img, crop_mask, (l, t, r, b)

    def _run_box(self, image, mask, box, config: InpaintRequest):
        """
        Inpaint a single cropped box.

        Args:
            image: [H, W, C] RGB
            mask: [H, W, 1]
            box: [left, top, right, bottom]
        Returns:
            (BGR IMAGE, [l, t, r, b])
        """
        crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
        return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
class DiffusionInpaintModel(InpaintModel):
    """Base class for diffusion (Stable-Diffusion style) inpainting models.

    Adds cropper/extender (outpainting) handling, `sd_scale` resizing, mask
    blurring, and optional histogram matching around the core `forward`.
    """

    def __init__(self, device, **kwargs):
        # `model_info` is required; its `path` identifies the checkpoint.
        self.model_info = kwargs["model_info"]
        self.model_id_or_path = self.model_info.path
        super().__init__(device, **kwargs)

    @torch.no_grad()
    def __call__(self, image, mask, config: InpaintRequest):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        # boxes = boxes_from_mask(mask)
        if config.use_croper:
            # Inpaint only the user-selected crop, then paste it back into a
            # BGR copy of the full image.
            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
            inpaint_result = image[:, :, ::-1].copy()
            inpaint_result[t:b, l:r, :] = crop_image
        elif config.use_extender:
            inpaint_result = self._do_outpainting(image, config)
        else:
            inpaint_result = self._scaled_pad_forward(image, mask, config)
        return inpaint_result

    def _do_outpainting(self, image, config: InpaintRequest):
        # The extender rectangle and the image share one coordinate system;
        # extender_x/y may be negative (rect extends beyond the image).
        # First crop the part of `image` that overlaps the outpainting rect.
        image_h, image_w = image.shape[:2]

        cropper_l = config.extender_x
        cropper_t = config.extender_y
        cropper_r = config.extender_x + config.extender_width
        cropper_b = config.extender_y + config.extender_height
        image_l = 0
        image_t = 0
        image_r = image_w
        image_b = image_h

        # Intersection of the extender rect and the image (IoU-style box math).
        l = max(cropper_l, image_l)
        t = max(cropper_t, image_t)
        r = min(cropper_r, image_r)
        b = min(cropper_b, image_b)

        assert (
            0 <= l < r and 0 <= t < b
        ), f"cropper and image not overlap, {l},{t},{r},{b}"

        cropped_image = image[t:b, l:r, :]
        # How far the extender rect sticks out past each image edge.
        padding_l = max(0, image_l - cropper_l)
        padding_t = max(0, image_t - cropper_t)
        padding_r = max(0, cropper_r - image_r)
        padding_b = max(0, cropper_b - image_b)

        # Expand the crop by the padding; `mask_image` marks the new border
        # area that needs to be generated.
        expanded_image, mask_image = expand_image(
            cropped_image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
        )

        # Inpaint the expanded crop; result of _scaled_pad_forward is BGR.
        expanded_cropped_result_image = self._scaled_pad_forward(
            expanded_image, mask_image, config
        )

        # RGB -> BGR
        outpainting_image = cv2.copyMakeBorder(
            image,
            left=padding_l,
            top=padding_t,
            right=padding_r,
            bottom=padding_b,
            borderType=cv2.BORDER_CONSTANT,
            value=0,
        )[:, :, ::-1]

        # Paste the outpainted crop onto the enlarged canvas; no blending is
        # needed for this step.
        paste_t = 0 if config.extender_y < 0 else config.extender_y
        paste_l = 0 if config.extender_x < 0 else config.extender_x

        outpainting_image[
            paste_t : paste_t + expanded_cropped_result_image.shape[0],
            paste_l : paste_l + expanded_cropped_result_image.shape[1],
            :,
        ] = expanded_cropped_result_image
        return outpainting_image

    def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
        # Optionally downscale by sd_scale before inference, then resize the
        # result back to the original size (cubic interpolation).
        longer_side_length = int(config.sd_scale * max(image.shape[:2]))
        origin_size = image.shape[:2]
        downsize_image = resize_max_size(image, size_limit=longer_side_length)
        downsize_mask = resize_max_size(mask, size_limit=longer_side_length)
        if config.sd_scale != 1:
            logger.info(
                f"Resize image to do sd inpainting: {image.shape} -> {downsize_image.shape}"
            )
        inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
        # only paste masked area result
        inpaint_result = cv2.resize(
            inpaint_result,
            (origin_size[1], origin_size[0]),
            interpolation=cv2.INTER_CUBIC,
        )
        return inpaint_result

    def set_scheduler(self, config: InpaintRequest):
        # Swap the diffusers scheduler to match the requested sampler.
        # LCM-LoRA forces the LCM sampler when the model supports it.
        scheduler_config = self.model.scheduler.config
        sd_sampler = config.sd_sampler
        if config.sd_lcm_lora and self.model_info.support_lcm_lora:
            sd_sampler = SDSampler.lcm
            logger.info(f"LCM Lora enabled, use {sd_sampler} sampler")
        scheduler = get_scheduler(sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

    def forward_pre_process(self, image, mask, config):
        # Feather the mask edge with a Gaussian blur (kernel size must be odd).
        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return image, mask

    def forward_post_process(self, result, image, mask, config):
        # Optionally match the result's histogram to the original (BGR) image.
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        # In extender mode, re-blur the mask so the later blending step uses
        # the feathered edge.
        if config.use_extender and config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask