# gfpganer.py
  1. import os
  2. import cv2
  3. import torch
  4. from torch.hub import get_dir
  5. from torchvision.transforms.functional import normalize
  6. from .basicsr.img_util import img2tensor, tensor2img
  7. from .facexlib.utils.face_restoration_helper import FaceRestoreHelper
  8. from .gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
  9. class MyGFPGANer:
  10. """Helper for restoration with GFPGAN.
  11. It will detect and crop faces, and then resize the faces to 512x512.
  12. GFPGAN is used to restored the resized faces.
  13. The background is upsampled with the bg_upsampler.
  14. Finally, the faces will be pasted back to the upsample background image.
  15. Args:
  16. model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
  17. upscale (float): The upscale of the final output. Default: 2.
  18. arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
  19. channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
  20. bg_upsampler (nn.Module): The upsampler for the background. Default: None.
  21. """
  22. def __init__(
  23. self,
  24. model_path,
  25. upscale=2,
  26. arch="clean",
  27. channel_multiplier=2,
  28. bg_upsampler=None,
  29. device=None,
  30. ):
  31. self.upscale = upscale
  32. self.bg_upsampler = bg_upsampler
  33. # initialize model
  34. self.device = (
  35. torch.device("cuda" if torch.cuda.is_available() else "cpu")
  36. if device is None
  37. else device
  38. )
  39. # initialize the GFP-GAN
  40. if arch == "clean":
  41. self.gfpgan = GFPGANv1Clean(
  42. out_size=512,
  43. num_style_feat=512,
  44. channel_multiplier=channel_multiplier,
  45. decoder_load_path=None,
  46. fix_decoder=False,
  47. num_mlp=8,
  48. input_is_latent=True,
  49. different_w=True,
  50. narrow=1,
  51. sft_half=True,
  52. )
  53. elif arch == "RestoreFormer":
  54. from .gfpgan.archs.restoreformer_arch import RestoreFormer
  55. self.gfpgan = RestoreFormer()
  56. hub_dir = get_dir()
  57. model_dir = os.path.join(hub_dir, "checkpoints")
  58. # initialize face helper
  59. self.face_helper = FaceRestoreHelper(
  60. upscale,
  61. face_size=512,
  62. crop_ratio=(1, 1),
  63. det_model="retinaface_resnet50",
  64. save_ext="png",
  65. use_parse=True,
  66. device=self.device,
  67. model_rootpath=model_dir,
  68. )
  69. loadnet = torch.load(model_path)
  70. if "params_ema" in loadnet:
  71. keyname = "params_ema"
  72. else:
  73. keyname = "params"
  74. self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
  75. self.gfpgan.eval()
  76. self.gfpgan = self.gfpgan.to(self.device)
  77. @torch.no_grad()
  78. def enhance(
  79. self,
  80. img,
  81. has_aligned=False,
  82. only_center_face=False,
  83. paste_back=True,
  84. weight=0.5,
  85. ):
  86. self.face_helper.clean_all()
  87. if has_aligned: # the inputs are already aligned
  88. img = cv2.resize(img, (512, 512))
  89. self.face_helper.cropped_faces = [img]
  90. else:
  91. self.face_helper.read_image(img)
  92. # get face landmarks for each face
  93. self.face_helper.get_face_landmarks_5(
  94. only_center_face=only_center_face, eye_dist_threshold=5
  95. )
  96. # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
  97. # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
  98. # align and warp each face
  99. self.face_helper.align_warp_face()
  100. # face restoration
  101. for cropped_face in self.face_helper.cropped_faces:
  102. # prepare data
  103. cropped_face_t = img2tensor(
  104. cropped_face / 255.0, bgr2rgb=True, float32=True
  105. )
  106. normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
  107. cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
  108. try:
  109. output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
  110. # convert to image
  111. restored_face = tensor2img(
  112. output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
  113. )
  114. except RuntimeError as error:
  115. print(f"\tFailed inference for GFPGAN: {error}.")
  116. restored_face = cropped_face
  117. restored_face = restored_face.astype("uint8")
  118. self.face_helper.add_restored_face(restored_face)
  119. if not has_aligned and paste_back:
  120. # upsample the background
  121. if self.bg_upsampler is not None:
  122. # Now only support RealESRGAN for upsampling background
  123. bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
  124. else:
  125. bg_img = None
  126. self.face_helper.get_inverse_affine(None)
  127. # paste each restored face to the input image
  128. restored_img = self.face_helper.paste_faces_to_input_image(
  129. upsample_img=bg_img
  130. )
  131. return (
  132. self.face_helper.cropped_faces,
  133. self.face_helper.restored_faces,
  134. restored_img,
  135. )
  136. else:
  137. return self.face_helper.cropped_faces, self.face_helper.restored_faces, None