# sdxl.py
import os

import cv2
import PIL.Image
import torch
from diffusers import AutoencoderKL
from loguru import logger

from sorawm.iopaint.schema import InpaintRequest, ModelType

from .base import DiffusionInpaintModel
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
from .original_sd_configs import get_config_files
from .utils import (
    enable_low_mem,
    get_torch_dtype,
    handle_from_pretrained_exceptions,
    is_local_files_only,
)
  17. class SDXL(DiffusionInpaintModel):
  18. name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
  19. pad_mod = 8
  20. min_size = 512
  21. lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
  22. model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
  23. def init_model(self, device: torch.device, **kwargs):
  24. from diffusers.pipelines import StableDiffusionXLInpaintPipeline
  25. use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
  26. if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
  27. num_in_channels = 4
  28. else:
  29. num_in_channels = 9
  30. if os.path.isfile(self.model_id_or_path):
  31. self.model = StableDiffusionXLInpaintPipeline.from_single_file(
  32. self.model_id_or_path,
  33. torch_dtype=torch_dtype,
  34. num_in_channels=num_in_channels,
  35. load_safety_checker=False,
  36. original_config_file=get_config_files()["xl"],
  37. )
  38. else:
  39. model_kwargs = {
  40. **kwargs.get("pipe_components", {}),
  41. "local_files_only": is_local_files_only(**kwargs),
  42. }
  43. if "vae" not in model_kwargs:
  44. vae = AutoencoderKL.from_pretrained(
  45. "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
  46. )
  47. model_kwargs["vae"] = vae
  48. self.model = handle_from_pretrained_exceptions(
  49. StableDiffusionXLInpaintPipeline.from_pretrained,
  50. pretrained_model_name_or_path=self.model_id_or_path,
  51. torch_dtype=torch_dtype,
  52. variant="fp16",
  53. **model_kwargs,
  54. )
  55. enable_low_mem(self.model, kwargs.get("low_mem", False))
  56. if kwargs.get("cpu_offload", False) and use_gpu:
  57. logger.info("Enable sequential cpu offload")
  58. self.model.enable_sequential_cpu_offload(gpu_id=0)
  59. else:
  60. self.model = self.model.to(device)
  61. if kwargs["sd_cpu_textencoder"]:
  62. logger.info("Run Stable Diffusion TextEncoder on CPU")
  63. self.model.text_encoder = CPUTextEncoderWrapper(
  64. self.model.text_encoder, torch_dtype
  65. )
  66. self.model.text_encoder_2 = CPUTextEncoderWrapper(
  67. self.model.text_encoder_2, torch_dtype
  68. )
  69. self.callback = kwargs.pop("callback", None)
  70. def forward(self, image, mask, config: InpaintRequest):
  71. """Input image and output image have same size
  72. image: [H, W, C] RGB
  73. mask: [H, W, 1] 255 means area to repaint
  74. return: BGR IMAGE
  75. """
  76. self.set_scheduler(config)
  77. img_h, img_w = image.shape[:2]
  78. output = self.model(
  79. image=PIL.Image.fromarray(image),
  80. prompt=config.prompt,
  81. negative_prompt=config.negative_prompt,
  82. mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
  83. num_inference_steps=config.sd_steps,
  84. strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
  85. guidance_scale=config.sd_guidance_scale,
  86. output_type="np",
  87. callback_on_step_end=self.callback,
  88. height=img_h,
  89. width=img_w,
  90. generator=torch.manual_seed(config.sd_seed),
  91. ).images[0]
  92. output = (output * 255).round().astype("uint8")
  93. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
  94. return output