kandinsky.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import cv2
  2. import numpy as np
  3. import PIL.Image
  4. import torch
  5. from sorawm.iopaint.const import KANDINSKY22_NAME
  6. from sorawm.iopaint.schema import InpaintRequest
  7. from .base import DiffusionInpaintModel
  8. from .utils import enable_low_mem, get_torch_dtype, is_local_files_only
  9. class Kandinsky(DiffusionInpaintModel):
  10. pad_mod = 64
  11. min_size = 512
  12. def init_model(self, device: torch.device, **kwargs):
  13. from diffusers import AutoPipelineForInpainting
  14. use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
  15. model_kwargs = {
  16. "torch_dtype": torch_dtype,
  17. "local_files_only": is_local_files_only(**kwargs),
  18. }
  19. self.model = AutoPipelineForInpainting.from_pretrained(
  20. self.name, **model_kwargs
  21. ).to(device)
  22. enable_low_mem(self.model, kwargs.get("low_mem", False))
  23. self.callback = kwargs.pop("callback", None)
  24. def forward(self, image, mask, config: InpaintRequest):
  25. """Input image and output image have same size
  26. image: [H, W, C] RGB
  27. mask: [H, W, 1] 255 means area to repaint
  28. return: BGR IMAGE
  29. """
  30. self.set_scheduler(config)
  31. generator = torch.manual_seed(config.sd_seed)
  32. mask = mask.astype(np.float32) / 255
  33. img_h, img_w = image.shape[:2]
  34. # kandinsky 没有 strength
  35. output = self.model(
  36. prompt=config.prompt,
  37. negative_prompt=config.negative_prompt,
  38. image=PIL.Image.fromarray(image),
  39. mask_image=mask[:, :, 0],
  40. height=img_h,
  41. width=img_w,
  42. num_inference_steps=config.sd_steps,
  43. guidance_scale=config.sd_guidance_scale,
  44. output_type="np",
  45. callback_on_step_end=self.callback,
  46. generator=generator,
  47. ).images[0]
  48. output = (output * 255).round().astype("uint8")
  49. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
  50. return output
  51. class Kandinsky22(Kandinsky):
  52. name = KANDINSKY22_NAME