# controlnet.py

import cv2
import PIL.Image
import torch
from diffusers import ControlNetModel
from loguru import logger

from sorawm.iopaint.schema import InpaintRequest, ModelType

from .base import DiffusionInpaintModel
from .helper.controlnet_preprocess import (
    make_canny_control_image,
    make_depth_control_image,
    make_inpaint_control_image,
    make_openpose_control_image,
)
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
from .original_sd_configs import get_config_files
from .utils import (
    enable_low_mem,
    get_scheduler,
    get_torch_dtype,
    handle_from_pretrained_exceptions,
    is_local_files_only,
)


class ControlNet(DiffusionInpaintModel):
    name = "controlnet"
    pad_mod = 8  # input sides are padded to a multiple of 8
    min_size = 512

    @property
    def lcm_lora_id(self):
        # Pick the LCM-LoRA checkpoint matching the base model family.
        if self.model_info.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SD_INPAINT,
        ]:
            return "latent-consistency/lcm-lora-sdv1-5"
        if self.model_info.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            return "latent-consistency/lcm-lora-sdxl"
        raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
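
    # lcm_lora_id is assumed to be consumed by the base class when LCM sampling
    # is selected, roughly along the lines of
    #   self.model.load_lora_weights(self.lcm_lora_id)
    # (a hedged sketch; the actual wiring lives in DiffusionInpaintModel).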

    def init_model(self, device: torch.device, **kwargs):
        model_info = kwargs["model_info"]
        controlnet_method = kwargs["controlnet_method"]

        self.model_info = model_info
        self.controlnet_method = controlnet_method

        model_kwargs = {
            # Reuse any already-created pipeline components passed in by the caller.
            **kwargs.get("pipe_components", {}),
            "local_files_only": is_local_files_only(**kwargs),
        }
        self.local_files_only = model_kwargs["local_files_only"]

        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
            "cpu_offload", False
        )
        if disable_nsfw_checker:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.torch_dtype = torch_dtype

        original_config_file_name = "v1"
        if model_info.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SD_INPAINT,
        ]:
            from diffusers import StableDiffusionControlNetInpaintPipeline as PipeClass

            original_config_file_name = "v1"
        elif model_info.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            from diffusers import (
                StableDiffusionXLControlNetInpaintPipeline as PipeClass,
            )

            original_config_file_name = "xl"
        else:
            # Without this guard PipeClass would be unbound below.
            raise NotImplementedError(
                f"Unsupported controlnet model type {model_info.model_type}"
            )

        controlnet = ControlNetModel.from_pretrained(
            pretrained_model_name_or_path=controlnet_method,
            local_files_only=model_kwargs["local_files_only"],
            torch_dtype=self.torch_dtype,
        )
        if model_info.is_single_file_diffusers:
            # UNet input channels: 4 for a plain SD checkpoint, 9 for an inpaint
            # checkpoint (4 latent + 4 masked-image latent + 1 mask channel).
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9

            self.model = PipeClass.from_single_file(
                model_info.path,
                controlnet=controlnet,
                load_safety_checker=not disable_nsfw_checker,
                torch_dtype=torch_dtype,
                original_config_file=get_config_files()[original_config_file_name],
                **model_kwargs,
            )
        else:
            self.model = handle_from_pretrained_exceptions(
                PipeClass.from_pretrained,
                pretrained_model_name_or_path=model_info.path,
                controlnet=controlnet,
                variant="fp16",
                torch_dtype=torch_dtype,
                **model_kwargs,
            )

        enable_low_mem(self.model, kwargs.get("low_mem", False))

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )

        self.callback = kwargs.pop("callback", None)

    def switch_controlnet_method(self, new_method: str):
        # Swap only the ControlNet weights; the rest of the pipeline is reused.
        self.controlnet_method = new_method
        controlnet = ControlNetModel.from_pretrained(
            new_method,
            local_files_only=self.local_files_only,
            torch_dtype=self.torch_dtype,
        ).to(self.model.device)
        self.model.controlnet = controlnet

    def _get_control_image(self, image, mask):
        # Dispatch on the method name: the hub id (e.g. ".../sd-controlnet-canny")
        # encodes which preprocessor the ControlNet expects.
        if "canny" in self.controlnet_method:
            control_image = make_canny_control_image(image)
        elif "openpose" in self.controlnet_method:
            control_image = make_openpose_control_image(image)
        elif "depth" in self.controlnet_method:
            control_image = make_depth_control_image(image)
        elif "inpaint" in self.controlnet_method:
            control_image = make_inpaint_control_image(image, mask)
        else:
            raise NotImplementedError(f"{self.controlnet_method} not implemented")
        return control_image
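
    # For reference, a typical canny preprocessor in the spirit of
    # make_canny_control_image (a hedged sketch; the actual helper lives in
    # .helper.controlnet_preprocess and may differ):
    #
    #   import numpy as np
    #   canny = cv2.Canny(image, 100, 200)          # [H, W] uint8 edge map
    #   canny = np.stack([canny] * 3, axis=-1)      # replicate to 3 channels
    #   control_image = PIL.Image.fromarray(canny)  # PIL image for the pipeline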

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have the same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        scheduler_config = self.model.scheduler.config
        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
        self.model.scheduler = scheduler

        img_h, img_w = image.shape[:2]
        control_image = self._get_control_image(image, mask)
        mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
        image = PIL.Image.fromarray(image)

        output = self.model(
            image=image,
            mask_image=mask_image,
            control_image=control_image,
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            controlnet_conditioning_scale=config.controlnet_conditioning_scale,
        ).images[0]

        # Pipeline returns float RGB in [0, 1]; convert to uint8 BGR for OpenCV callers.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
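

# A hedged usage sketch (assumes DiffusionInpaintModel.__init__ forwards its
# kwargs to init_model; `model_info` stands for a ModelInfo object and is
# illustrative, not defined in this module):
#
#   model = ControlNet(
#       device=torch.device("cuda"),
#       model_info=model_info,
#       controlnet_method="lllyasviel/control_v11p_sd15_canny",
#       disable_nsfw=True,
#       sd_cpu_textencoder=False,
#   )
#   bgr = model.forward(rgb_image, mask, InpaintRequest(prompt="a red sofa"))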