# paint_by_example.py
import cv2
import PIL
import PIL.Image
import torch
from loguru import logger

from sorawm.iopaint.helper import decode_base64_to_image
from sorawm.iopaint.schema import InpaintRequest

from .base import DiffusionInpaintModel
from .utils import enable_low_mem, get_torch_dtype, is_local_files_only
  10. class PaintByExample(DiffusionInpaintModel):
  11. name = "Fantasy-Studio/Paint-by-Example"
  12. pad_mod = 8
  13. min_size = 512
  14. def init_model(self, device: torch.device, **kwargs):
  15. from diffusers import DiffusionPipeline
  16. use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
  17. model_kwargs = {
  18. "local_files_only": is_local_files_only(**kwargs),
  19. }
  20. if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
  21. logger.info("Disable Paint By Example Model NSFW checker")
  22. model_kwargs.update(
  23. dict(safety_checker=None, requires_safety_checker=False)
  24. )
  25. self.model = DiffusionPipeline.from_pretrained(
  26. self.name, torch_dtype=torch_dtype, **model_kwargs
  27. )
  28. enable_low_mem(self.model, kwargs.get("low_mem", False))
  29. # TODO: gpu_id
  30. if kwargs.get("cpu_offload", False) and use_gpu:
  31. self.model.image_encoder = self.model.image_encoder.to(device)
  32. self.model.enable_sequential_cpu_offload(gpu_id=0)
  33. else:
  34. self.model = self.model.to(device)
  35. def forward(self, image, mask, config: InpaintRequest):
  36. """Input image and output image have same size
  37. image: [H, W, C] RGB
  38. mask: [H, W, 1] 255 means area to repaint
  39. return: BGR IMAGE
  40. """
  41. if config.paint_by_example_example_image is None:
  42. raise ValueError("paint_by_example_example_image is required")
  43. example_image, _, _, _ = decode_base64_to_image(
  44. config.paint_by_example_example_image
  45. )
  46. output = self.model(
  47. image=PIL.Image.fromarray(image),
  48. mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
  49. example_image=PIL.Image.fromarray(example_image),
  50. num_inference_steps=config.sd_steps,
  51. guidance_scale=config.sd_guidance_scale,
  52. negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
  53. output_type="np.array",
  54. generator=torch.manual_seed(config.sd_seed),
  55. ).images[0]
  56. output = (output * 255).round().astype("uint8")
  57. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
  58. return output