sd.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import cv2
  2. import PIL.Image
  3. import torch
  4. from loguru import logger
  5. from sorawm.iopaint.schema import InpaintRequest, ModelType
  6. from .base import DiffusionInpaintModel
  7. from .helper.cpu_text_encoder import CPUTextEncoderWrapper
  8. from .original_sd_configs import get_config_files
  9. from .utils import (
  10. enable_low_mem,
  11. get_torch_dtype,
  12. handle_from_pretrained_exceptions,
  13. is_local_files_only,
  14. )
  15. class SD(DiffusionInpaintModel):
  16. pad_mod = 8
  17. min_size = 512
  18. lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
  19. def init_model(self, device: torch.device, **kwargs):
  20. from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
  21. use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
  22. model_kwargs = {
  23. **kwargs.get("pipe_components", {}),
  24. "local_files_only": is_local_files_only(**kwargs),
  25. }
  26. disable_nsfw_checker = kwargs.get("disable_nsfw", False) or kwargs.get(
  27. "cpu_offload", False
  28. )
  29. if disable_nsfw_checker:
  30. logger.info("Disable Stable Diffusion Model NSFW checker")
  31. model_kwargs.update(
  32. dict(
  33. safety_checker=None,
  34. feature_extractor=None,
  35. requires_safety_checker=False,
  36. )
  37. )
  38. if self.model_info.is_single_file_diffusers:
  39. if self.model_info.model_type == ModelType.DIFFUSERS_SD:
  40. model_kwargs["num_in_channels"] = 4
  41. else:
  42. model_kwargs["num_in_channels"] = 9
  43. self.model = StableDiffusionInpaintPipeline.from_single_file(
  44. self.model_id_or_path,
  45. torch_dtype=torch_dtype,
  46. load_safety_checker=not disable_nsfw_checker,
  47. original_config_file=get_config_files()["v1"],
  48. **model_kwargs,
  49. )
  50. else:
  51. self.model = handle_from_pretrained_exceptions(
  52. StableDiffusionInpaintPipeline.from_pretrained,
  53. pretrained_model_name_or_path=self.model_id_or_path,
  54. variant="fp16",
  55. torch_dtype=torch_dtype,
  56. **model_kwargs,
  57. )
  58. enable_low_mem(self.model, kwargs.get("low_mem", False))
  59. if kwargs.get("cpu_offload", False) and use_gpu:
  60. logger.info("Enable sequential cpu offload")
  61. self.model.enable_sequential_cpu_offload(gpu_id=0)
  62. else:
  63. self.model = self.model.to(device)
  64. if kwargs.get("sd_cpu_textencoder", False):
  65. logger.info("Run Stable Diffusion TextEncoder on CPU")
  66. self.model.text_encoder = CPUTextEncoderWrapper(
  67. self.model.text_encoder, torch_dtype
  68. )
  69. self.callback = kwargs.pop("callback", None)
  70. def forward(self, image, mask, config: InpaintRequest):
  71. """Input image and output image have same size
  72. image: [H, W, C] RGB
  73. mask: [H, W, 1] 255 means area to repaint
  74. return: BGR IMAGE
  75. """
  76. self.set_scheduler(config)
  77. img_h, img_w = image.shape[:2]
  78. output = self.model(
  79. image=PIL.Image.fromarray(image),
  80. prompt=config.prompt,
  81. negative_prompt=config.negative_prompt,
  82. mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
  83. num_inference_steps=config.sd_steps,
  84. strength=config.sd_strength,
  85. guidance_scale=config.sd_guidance_scale,
  86. output_type="np",
  87. callback_on_step_end=self.callback,
  88. height=img_h,
  89. width=img_w,
  90. generator=torch.manual_seed(config.sd_seed),
  91. ).images[0]
  92. output = (output * 255).round().astype("uint8")
  93. output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
  94. return output
  95. class SD15(SD):
  96. name = "runwayml/stable-diffusion-inpainting"
  97. model_id_or_path = "runwayml/stable-diffusion-inpainting"
  98. class Anything4(SD):
  99. name = "Sanster/anything-4.0-inpainting"
  100. model_id_or_path = "Sanster/anything-4.0-inpainting"
  101. class RealisticVision14(SD):
  102. name = "Sanster/Realistic_Vision_V1.4-inpainting"
  103. model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"
  104. class SD2(SD):
  105. name = "stabilityai/stable-diffusion-2-inpainting"
  106. model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"