restoreformer.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import cv2
  2. import numpy as np
  3. from loguru import logger
  4. from sorawm.iopaint.helper import download_model
  5. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  6. from sorawm.iopaint.schema import RunPluginRequest
  7. class RestoreFormerPlugin(BasePlugin):
  8. name = "RestoreFormer"
  9. support_gen_image = True
  10. def __init__(self, device, upscaler=None):
  11. super().__init__()
  12. from .gfpganer import MyGFPGANer
  13. url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
  14. model_md5 = "eaeeff6c4a1caa1673977cb374e6f699"
  15. model_path = download_model(url, model_md5)
  16. logger.info(f"RestoreFormer model path: {model_path}")
  17. self.face_enhancer = MyGFPGANer(
  18. model_path=model_path,
  19. upscale=1,
  20. arch="RestoreFormer",
  21. channel_multiplier=2,
  22. device=device,
  23. bg_upsampler=upscaler.model if upscaler is not None else None,
  24. )
  25. def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  26. weight = 0.5
  27. bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
  28. logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
  29. _, _, bgr_output = self.face_enhancer.enhance(
  30. bgr_np_img,
  31. has_aligned=False,
  32. only_center_face=False,
  33. paste_back=True,
  34. weight=weight,
  35. )
  36. logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
  37. return bgr_output