# brushnet_xl_wrapper.py

import cv2
import numpy as np
import PIL.Image
import torch
from loguru import logger

from ...const import SDXL_BRUSHNET_CHOICES
from ...schema import InpaintRequest, ModelType
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..original_sd_configs import get_config_files
from ..utils import (
    enable_low_mem,
    get_torch_dtype,
    handle_from_pretrained_exceptions,
    is_local_files_only,
)
from .brushnet import BrushNetModel
from .brushnet_unet_forward import brushnet_unet_forward
from .unet_2d_blocks import (
    CrossAttnDownBlock2D_forward,
    CrossAttnUpBlock2D_forward,
    DownBlock2D_forward,
    UpBlock2D_forward,
)


class BrushNetXLWrapper(DiffusionInpaintModel):
    pad_mod = 8
    min_size = 1024
    support_brushnet = True
    support_lcm_lora = False

    def init_model(self, device: torch.device, **kwargs):
        from .pipeline_brushnet_sd_xl import StableDiffusionXLBrushNetPipeline

        self.model_info = kwargs["model_info"]
        self.brushnet_xl_method = SDXL_BRUSHNET_CHOICES[0]
        # self.brushnet_xl_method = kwargs["brushnet_xl_method"]

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.torch_dtype = torch_dtype

        model_kwargs = {
            **kwargs.get("pipe_components", {}),
            "local_files_only": is_local_files_only(**kwargs),
        }
        self.local_files_only = model_kwargs["local_files_only"]

        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
            "cpu_offload", False
        )
        if disable_nsfw_checker:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )
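
        # NOTE: the condition above also drops the safety checker whenever
        # cpu_offload is enabled; presumably this avoids shuttling the checker
        # on and off the GPU during offloaded inference. That rationale is an
        # inference from the code, not documented behavior.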

        logger.info(f"Loading BrushNet model from {self.brushnet_xl_method}")
        brushnet = BrushNetModel.from_pretrained(
            self.brushnet_xl_method, torch_dtype=torch_dtype
        )

        if self.model_info.is_single_file_diffusers:
            # A base SDXL checkpoint has a 4-channel UNet input; inpainting
            # variants use 9 channels (latents + mask + masked-image latents).
            if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9

            self.model = StableDiffusionXLBrushNetPipeline.from_single_file(
                self.model_id_or_path,
                torch_dtype=torch_dtype,
                load_safety_checker=not disable_nsfw_checker,
                # Use the SDXL original config, not the SD 1.x ("v1") one.
                original_config_file=get_config_files()["xl"],
                brushnet=brushnet,
                **model_kwargs,
            )
        else:
            self.model = handle_from_pretrained_exceptions(
                StableDiffusionXLBrushNetPipeline.from_pretrained,
                pretrained_model_name_or_path=self.model_id_or_path,
                variant="fp16",
                torch_dtype=torch_dtype,
                brushnet=brushnet,
                **model_kwargs,
            )
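
        # Two load paths: a single-file checkpoint (one .safetensors/.ckpt
        # file) must be expanded into a diffusers pipeline using the original
        # SDXL config, while a diffusers-layout repo loads directly. From its
        # name and usage, handle_from_pretrained_exceptions appears to retry
        # the load when, e.g., the fp16 variant is unavailable; that reading
        # is inferred from this call site, not from its implementation.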

        enable_low_mem(self.model, kwargs.get("low_mem", False))

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                # NOTE: only the first text encoder is wrapped; SDXL's
                # text_encoder_2 stays on the target device.
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )

        self.callback = kwargs.pop("callback", None)

        # Monkey patch the UNet's forward method with brushnet_unet_forward.
        self.model.unet.forward = brushnet_unet_forward.__get__(
            self.model.unet, self.model.unet.__class__
        )
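        # `func.__get__(obj, cls)` turns a plain function into a method bound
        # to `obj`, so the patched forward receives `self` like the original.
        # A minimal sketch of the same trick on a hypothetical class:
        #
        #     class A:
        #         def forward(self, x):
        #             return x
        #
        #     def patched_forward(self, x):
        #         return x + 1
        #
        #     a = A()
        #     a.forward = patched_forward.__get__(a, A)  # instance-level patch
        #     assert a.forward(1) == 2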

        # Rebind the BrushNet block forwards the same way.
        for down_block in self.model.brushnet.down_blocks:
            down_block.forward = DownBlock2D_forward.__get__(
                down_block, down_block.__class__
            )
        for up_block in self.model.brushnet.up_blocks:
            up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)

        # Monkey patch the UNet's down/up blocks: cross-attention blocks get
        # the CrossAttn*_forward variants, plain blocks the *Block2D_forward
        # ones.
        for down_block in self.model.unet.down_blocks:
            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )
            else:
                down_block.forward = DownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )
        for up_block in self.model.unet.up_blocks:
            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )
            else:
                up_block.forward = UpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )

    def switch_brushnet_method(self, new_method: str):
        # Keep the attribute name consistent with init_model's
        # self.brushnet_xl_method.
        self.brushnet_xl_method = new_method
        brushnet_xl = BrushNetModel.from_pretrained(
            new_method,
            local_files_only=self.local_files_only,
            torch_dtype=self.torch_dtype,
        ).to(self.model.device)
        self.model.brushnet = brushnet_xl
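
    # Usage note: switching replaces only the BrushNet conditioning network;
    # the base SDXL pipeline (UNet, VAE, text encoders) stays loaded, e.g.:
    #     wrapper.switch_brushnet_method(SDXL_BRUSHNET_CHOICES[0])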

    def forward(self, image, mask, config: InpaintRequest):
        """Input and output images have the same size.

        image: [H, W, C], RGB
        mask: [H, W, 1], 255 marks the area to repaint
        return: BGR image
        """
        self.set_scheduler(config)

        img_h, img_w = image.shape[:2]

        # Zero out the masked region: the blacked-out image is what BrushNet
        # receives as conditioning.
        normalized_mask = mask.astype("float32") / 255.0
        image = image * (1 - normalized_mask)
        image = image.astype(np.uint8)

        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
            num_inference_steps=config.sd_steps,
            # strength=config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            brushnet_conditioning_scale=config.brushnet_conditioning_scale,
        ).images[0]

        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
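
# Minimal usage sketch (illustrative only, not a tested example; the
# constructor signature is assumed from DiffusionInpaintModel, and the
# InpaintRequest field values below are made-up placeholders):
#
#     wrapper = BrushNetXLWrapper(
#         torch.device("cuda"),
#         model_info=model_info,  # hypothetical ModelInfo for an SDXL model
#         disable_nsfw=True,
#         sd_cpu_textencoder=False,
#     )
#     request = InpaintRequest(
#         prompt="a red leather sofa",
#         sd_steps=30,
#         sd_guidance_scale=7.5,
#         sd_seed=42,
#         brushnet_conditioning_scale=1.0,
#     )
#     bgr = wrapper.forward(rgb_image, mask, request)  # mask: [H, W, 1] uint8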