gfpgan_plugin.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import cv2
  2. import numpy as np
  3. from loguru import logger
  4. from sorawm.iopaint.helper import download_model
  5. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  6. from sorawm.iopaint.schema import RunPluginRequest
  class GFPGANPlugin(BasePlugin):
      """Face-restoration plugin backed by the GFPGAN v1.4 checkpoint.

      The checkpoint is fetched via ``download_model`` (with an MD5 value) when
      the plugin is constructed; ``gen_image`` runs the face enhancer over a
      whole image with ``paste_back=True``.
      """

      name = "GFPGAN"
      support_gen_image = True

      def __init__(self, device, upscaler=None):
          """Download the GFPGAN weights and build the face enhancer.

          Args:
              device: Device passed to the enhancer and used to move its face
                  detector (presumably a ``torch.device`` or device string —
                  confirm against callers).
              upscaler: Optional object whose ``.model`` attribute is handed to
                  the enhancer as ``bg_upsampler``; ``None`` passes no upsampler.
          """
          super().__init__()
          # Imported inside __init__, so the module is only loaded once the
          # plugin is actually constructed.
          from .gfpganer import MyGFPGANer

          url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
          model_md5 = "94d735072630ab734561130a47bc44f8"
          model_path = download_model(url, model_md5)
          logger.info(f"GFPGAN model path: {model_path}")

          # Use GFPGAN for face enhancement
          self.face_enhancer = MyGFPGANer(
              model_path=model_path,
              upscale=1,
              arch="clean",
              channel_multiplier=2,
              device=device,
              bg_upsampler=upscaler.model if upscaler is not None else None,
          )

          # Move the detector module first, then its mean tensor. Tensor.to()
          # returns a new tensor rather than converting in place, so the result
          # must be assigned back. Module.to() only moves registered parameters
          # and buffers; if mean_tensor is a plain attribute (TODO confirm
          # against the detector implementation) the explicit move is what
          # actually places it on ``device``.
          face_det = self.face_enhancer.face_helper.face_det.to(device)
          face_det.mean_tensor = face_det.mean_tensor.to(device)
          self.face_enhancer.face_helper.face_det = face_det

      def gen_image(self, rgb_np_img: np.ndarray, req: RunPluginRequest) -> np.ndarray:
          """Restore faces in ``rgb_np_img``.

          Args:
              rgb_np_img: Input image in RGB channel order; converted to BGR
                  before being passed to the enhancer.
              req: Plugin request; not read by this plugin.

          Returns:
              The enhancer's output image, returned as-is in BGR channel order
              (it is NOT converted back to RGB).
          """
          weight = 0.5
          bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
          logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
          # enhance() returns a 3-tuple; only the last element is used —
          # presumably the full image with restored faces pasted back, given
          # paste_back=True.
          _, _, bgr_output = self.face_enhancer.enhance(
              bgr_np_img,
              has_aligned=False,
              only_center_face=False,
              paste_back=True,
              weight=weight,
          )
          logger.info(f"GFPGAN output shape: {bgr_output.shape}")
          return bgr_output