manga.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import os
  2. import random
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from loguru import logger
  8. from sorawm.iopaint.helper import download_model, get_cache_path_by_url, load_jit_model
  9. from sorawm.iopaint.schema import InpaintRequest
  10. from .base import InpaintModel
  # Download locations and MD5 checksums for the two TorchScript models that
  # the `Manga` inpainter loads. Each value may be overridden by setting an
  # environment variable with the same name as the constant.

  # Inpainting network (fills the masked region).
  MANGA_INPAINTOR_MODEL_URL = os.getenv(
      "MANGA_INPAINTOR_MODEL_URL",
      "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit",
  )
  MANGA_INPAINTOR_MODEL_MD5 = os.getenv("MANGA_INPAINTOR_MODEL_MD5", "7d8b269c4613b6b3768af714610da86c")

  # Line-extraction network ("erika").
  MANGA_LINE_MODEL_URL = os.getenv(
      "MANGA_LINE_MODEL_URL",
      "https://github.com/Sanster/models/releases/download/manga/erika.jit",
  )
  MANGA_LINE_MODEL_MD5 = os.getenv("MANGA_LINE_MODEL_MD5", "0c926d5a4af8450b0d00bc5b9a095644")
  class Manga(InpaintModel):
      """Grayscale manga inpainting backed by two TorchScript models.

      ``line_model`` (erika) produces a line map from the grayscale input.
      ``inpaintor_model`` then fills the masked region, conditioned on the
      grayscale image, the line map, the binary mask and a seeded noise tensor.
      """

      name = "manga"
      pad_mod = 16
      is_erase_model = True

      def init_model(self, device, **kwargs):
          """Load both TorchScript models onto *device* and set the RNG seed.

          Extra keyword arguments are accepted for interface compatibility
          and ignored.
          """
          self.inpaintor_model = load_jit_model(
              MANGA_INPAINTOR_MODEL_URL, device, MANGA_INPAINTOR_MODEL_MD5
          )
          self.line_model = load_jit_model(
              MANGA_LINE_MODEL_URL, device, MANGA_LINE_MODEL_MD5
          )
          # Fixed seed so forward() produces the same noise on every call.
          self.seed = 42

      @staticmethod
      def download():
          """Download both model files, verifying them against their MD5."""
          download_model(MANGA_INPAINTOR_MODEL_URL, MANGA_INPAINTOR_MODEL_MD5)
          download_model(MANGA_LINE_MODEL_URL, MANGA_LINE_MODEL_MD5)

      @staticmethod
      def is_downloaded() -> bool:
          """Return True when both model files exist in the local cache."""
          model_paths = [
              get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL),
              get_cache_path_by_url(MANGA_LINE_MODEL_URL),
          ]
          return all(os.path.exists(path) for path in model_paths)

      def forward(self, image, mask, config: InpaintRequest):
          """Inpaint the masked region of *image* in grayscale.

          Args:
              image: [H, W, C] RGB image (assumed uint8 in [0, 255] - TODO confirm).
              mask: [H, W, 1] array; values > 0.5 mark the region to fill.
              config: not used by this model.

          Returns:
              [H, W, 3] uint8 BGR image. All three channels are identical
              because the model works in grayscale.
          """
          # Re-seed every global RNG on each call so the noise fed to the
          # inpaintor, and therefore the output, is reproducible.
          seed = self.seed
          random.seed(seed)
          np.random.seed(seed)
          torch.manual_seed(seed)
          torch.cuda.manual_seed_all(seed)

          gray_np = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
          # [H, W] -> [1, 1, H, W] float32 tensor on the model device.
          gray_img = torch.from_numpy(
              gray_np[np.newaxis, np.newaxis, :, :].astype(np.float32)
          ).to(self.device)

          start = time.perf_counter()
          lines = self.line_model(gray_img)
          torch.cuda.empty_cache()
          lines = torch.clamp(lines, 0, 255)
          logger.info(f"erika_model time: {time.perf_counter() - start}")

          # [H, W, 1] -> [1, 1, H, W], binarized to float 0.0 / 1.0.
          mask_t = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device)
          mask_t = mask_t.permute(0, 3, 1, 2)
          mask_t = torch.where(mask_t > 0.5, 1.0, 0.0)
          noise = torch.randn_like(mask_t)
          ones = torch.ones_like(mask_t)

          # Rescale image and line map from [0, 255] to [-1, 1] for the inpaintor.
          gray_img = gray_img / 255 * 2 - 1.0
          lines = lines / 255 * 2 - 1.0

          start = time.perf_counter()
          inpainted_image = self.inpaintor_model(gray_img, lines, mask_t, noise, ones)
          logger.info(f"image_inpaintor_model time: {time.perf_counter() - start}")

          cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
          # Map the (presumably [-1, 1]) output back to [0, 255]. Clip before
          # the cast: out-of-range values would otherwise wrap around in uint8.
          cur_res = np.clip(cur_res * 127.5 + 127.5, 0, 255).astype(np.uint8)
          cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR)
          return cur_res