# brushnet_wrapper.py

import cv2
import numpy as np
import PIL.Image
import torch
from loguru import logger

from ...schema import InpaintRequest, ModelType
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..original_sd_configs import get_config_files
from ..utils import (
    enable_low_mem,
    get_torch_dtype,
    handle_from_pretrained_exceptions,
    is_local_files_only,
)
from .brushnet import BrushNetModel
from .brushnet_unet_forward import brushnet_unet_forward
from .unet_2d_blocks import (
    CrossAttnDownBlock2D_forward,
    CrossAttnUpBlock2D_forward,
    DownBlock2D_forward,
    UpBlock2D_forward,
)


class BrushNetWrapper(DiffusionInpaintModel):
    pad_mod = 8
    min_size = 512

    def init_model(self, device: torch.device, **kwargs):
        from .pipeline_brushnet import StableDiffusionBrushNetPipeline

        self.model_info = kwargs["model_info"]
        self.brushnet_method = kwargs["brushnet_method"]

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        self.torch_dtype = torch_dtype

        model_kwargs = {
            **kwargs.get("pipe_components", {}),
            "local_files_only": is_local_files_only(**kwargs),
        }
        self.local_files_only = model_kwargs["local_files_only"]

        disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
            "cpu_offload", False
        )
        if disable_nsfw_checker:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )

        logger.info(f"Loading BrushNet model from {self.brushnet_method}")
        brushnet = BrushNetModel.from_pretrained(
            self.brushnet_method, torch_dtype=torch_dtype
        )

        if self.model_info.is_single_file_diffusers:
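            # num_in_channels tells diffusers how many input channels the UNet
            # expects: a standard SD UNet takes 4 latent channels, while an SD
            # inpainting UNet takes 9 (4 noise latents + 4 masked-image
            # latents + 1 mask channel).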
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9

            self.model = StableDiffusionBrushNetPipeline.from_single_file(
                self.model_id_or_path,
                torch_dtype=torch_dtype,
                load_safety_checker=not disable_nsfw_checker,
                original_config_file=get_config_files()["v1"],
                brushnet=brushnet,
                **model_kwargs,
            )
        else:
            self.model = handle_from_pretrained_exceptions(
                StableDiffusionBrushNetPipeline.from_pretrained,
                pretrained_model_name_or_path=self.model_id_or_path,
                variant="fp16",
                torch_dtype=torch_dtype,
                brushnet=brushnet,
                **model_kwargs,
            )

        enable_low_mem(self.model, kwargs.get("low_mem", False))

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )

        self.callback = kwargs.pop("callback", None)

        # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
        self.model.unet.forward = brushnet_unet_forward.__get__(
            self.model.unet, self.model.unet.__class__
        )
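        # How the patching works: functions are descriptors, so
        # `func.__get__(obj, cls)` returns `func` bound to `obj`. Assigning that
        # to `obj.forward` overrides forward for this instance only, leaving the
        # class untouched. A minimal sketch of the idiom (Greeter/shout are
        # illustrative names, not part of this codebase):
        #
        #     class Greeter:
        #         def speak(self):
        #             return "hello"
        #
        #     def shout(self):
        #         return "HELLO"
        #
        #     g = Greeter()
        #     g.speak = shout.__get__(g, Greeter)  # bind shout to this instance
        #     assert g.speak() == "HELLO"
        #     assert Greeter().speak() == "hello"  # other instances unaffected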
        for down_block in self.model.brushnet.down_blocks:
            down_block.forward = DownBlock2D_forward.__get__(
                down_block, down_block.__class__
            )
        for up_block in self.model.brushnet.up_blocks:
            up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)

        # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
        for down_block in self.model.unet.down_blocks:
            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )
            else:
                down_block.forward = DownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )

        for up_block in self.model.unet.up_blocks:
            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )
            else:
                up_block.forward = UpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )

    def switch_brushnet_method(self, new_method: str):
        self.brushnet_method = new_method
        brushnet = BrushNetModel.from_pretrained(
            new_method,
            local_files_only=self.local_files_only,
            torch_dtype=self.torch_dtype,
        ).to(self.model.device)
        self.model.brushnet = brushnet
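    # Hot-swap sketch (the checkpoint path below is illustrative): replacing
    # only the BrushNet module reuses the already-loaded SD pipeline instead of
    # reloading the full checkpoint, and `.to(self.model.device)` keeps the new
    # weights on the same device as the rest of the pipeline.
    #
    #     wrapper.switch_brushnet_method("path/to/other_brushnet_checkpoint")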

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have the same size.
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means the area to repaint
        return: BGR IMAGE
        """
        self.set_scheduler(config)

        img_h, img_w = image.shape[:2]
        normalized_mask = mask.astype("float32") / 255.0
        image = image * (1 - normalized_mask)
        image = image.astype(np.uint8)
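        # The repaint region is zeroed out above so the pipeline only sees the
        # remaining background pixels; the mask itself is passed separately
        # below so BrushNet knows which pixels to synthesize.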
        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
            num_inference_steps=config.sd_steps,
            # strength=config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            brushnet_conditioning_scale=config.brushnet_conditioning_scale,
        ).images[0]

        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
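
# Hedged usage sketch (file paths and the InpaintRequest fields shown are
# assumptions for illustration; `wrapper` stands for an already-initialized
# BrushNetWrapper, which is normally constructed by the surrounding model
# manager rather than by hand):
#
#     import cv2
#     import numpy as np
#
#     rgb = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)  # [H, W, 3]
#     mask = np.zeros((*rgb.shape[:2], 1), dtype=np.uint8)
#     mask[100:200, 100:200] = 255            # 255 marks the area to repaint
#
#     req = InpaintRequest(prompt="a wooden bench", sd_seed=42)
#     bgr = wrapper.forward(rgb, mask, req)   # returns BGR, same size as input
#     cv2.imwrite("output.png", bgr)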