power_paint_v2.py

from itertools import chain

import cv2
import numpy as np
import PIL.Image
import torch
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer

from sorawm.iopaint.model.original_sd_configs import get_config_files
from sorawm.iopaint.schema import InpaintRequest, ModelType

from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
    enable_low_mem,
    get_torch_dtype,
    handle_from_pretrained_exceptions,
    is_local_files_only,
)
from .powerpaint_tokenizer import task_to_prompt
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_blocks import (
    CrossAttnDownBlock2D_forward,
    CrossAttnUpBlock2D_forward,
    DownBlock2D_forward,
    UpBlock2D_forward,
)
from .v2.unet_2d_condition import UNet2DConditionModel_forward


class PowerPaintV2(DiffusionInpaintModel):
    pad_mod = 8
    min_size = 512
    lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
    hf_model_id = "Sanster/PowerPaint_v2"
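    # pad_mod / min_size are presumably consumed by the DiffusionInpaintModel
    # base class: inputs are padded to a multiple of 8 (the SD latent
    # downsampling factor) and resized to at least 512 px before inpainting.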

    def init_model(self, device: torch.device, **kwargs):
        from .powerpaint_tokenizer import PowerPaintTokenizer
        from .v2.pipeline_PowerPaint_Brushnet_CA import (
            StableDiffusionPowerPaintBrushNetPipeline,
        )

        use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
        model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                dict(
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            )
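        # The checker is also dropped when cpu_offload is enabled, presumably
        # to keep the offloaded pipeline smaller.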

        text_encoder_brushnet = CLIPTextModel.from_pretrained(
            self.hf_model_id,
            subfolder="text_encoder_brushnet",
            variant="fp16",
            torch_dtype=torch_dtype,
            local_files_only=model_kwargs["local_files_only"],
        )

        brushnet = BrushNetModel.from_pretrained(
            self.hf_model_id,
            subfolder="PowerPaint_Brushnet",
            variant="fp16",
            torch_dtype=torch_dtype,
            local_files_only=model_kwargs["local_files_only"],
        )
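        # The two components above come from the Sanster/PowerPaint_v2 repo in
        # fp16: a task-aware CLIP text encoder and BrushNet, a ControlNet-style
        # side branch that injects masked-image features into the UNet blocks.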

        if self.model_info.is_single_file_diffusers:
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9
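            # A plain SD UNet takes 4 latent input channels; an SD inpainting
            # UNet takes 9 (4 noisy latents + 4 masked-image latents + 1 mask).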

            pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
                self.model_id_or_path,
                torch_dtype=torch_dtype,
                load_safety_checker=False,
                original_config_file=get_config_files()["v1"],
                brushnet=brushnet,
                text_encoder_brushnet=text_encoder_brushnet,
                **model_kwargs,
            )
        else:
            pipe = handle_from_pretrained_exceptions(
                StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
                pretrained_model_name_or_path=self.model_id_or_path,
                torch_dtype=torch_dtype,
                brushnet=brushnet,
                text_encoder_brushnet=text_encoder_brushnet,
                variant="fp16",
                **model_kwargs,
            )

        pipe.tokenizer = PowerPaintTokenizer(
            CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
        )
        self.model = pipe

        enable_low_mem(self.model, kwargs.get("low_mem", False))

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs["sd_cpu_textencoder"]:
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )
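        # Note: enable_sequential_cpu_offload moves each submodule to the GPU
        # only while it runs, so the explicit .to(device) (and the CPU text
        # encoder wrapper) apply only in the non-offload branch.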

        self.callback = kwargs.pop("callback", None)

        # Monkey patch the UNet forward to the BrushNet-aware
        # UNet2DConditionModel_forward
        self.model.unet.forward = UNet2DConditionModel_forward.__get__(
            self.model.unet, self.model.unet.__class__
        )
        # Monkey patch the UNet and BrushNet down/up blocks to their
        # BrushNet-aware forwards
        for down_block in chain(
            self.model.unet.down_blocks, self.model.brushnet.down_blocks
        ):
            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )
            else:
                down_block.forward = DownBlock2D_forward.__get__(
                    down_block, down_block.__class__
                )

        for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )
            else:
                up_block.forward = UpBlock2D_forward.__get__(
                    up_block, up_block.__class__
                )
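
    # A minimal sketch of the descriptor trick used above (hypothetical names,
    # for illustration only):
    #
    #     def patched_forward(self, hidden_states):
    #         return hidden_states
    #
    #     block.forward = patched_forward.__get__(block, block.__class__)
    #
    # A plain function is a descriptor, so __get__ returns a method bound to
    # this one instance: block.forward(x) then runs patched_forward(block, x)
    # without touching the class or any other block.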

    def forward(self, image, mask, config: InpaintRequest):
        """Input and output images have the same size.

        image: [H, W, C] RGB
        mask: [H, W, 1], 255 marks the area to repaint
        return: BGR image
        """
        self.set_scheduler(config)

        image = image * (1 - mask / 255.0)
        img_h, img_w = image.shape[:2]

        image = PIL.Image.fromarray(image.astype(np.uint8))
        mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
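        # Zero out the repaint region of the init image ((1 - mask / 255)
        # broadcasts over the channel axis) and expand the single-channel mask
        # to RGB, which the BrushNet pipeline presumably expects.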

        promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
            config.powerpaint_task
        )
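        # promptA/promptB carry PowerPaint's learned task tokens, promptU the
        # free-form user prompt; tradoff (config.fitting_degree) presumably
        # blends the two task embeddings, as described in the PowerPaint paper.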

        output = self.model(
            image=image,
            mask=mask,
            promptA=promptA,
            promptB=promptB,
            promptU=config.prompt,
            tradoff=config.fitting_degree,
            tradoff_nag=config.fitting_degree,
            negative_promptA=negative_promptA,
            negative_promptB=negative_promptB,
            negative_promptU=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            # strength=config.sd_strength,
            brushnet_conditioning_scale=1.0,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]
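
        # output_type="np" yields a float RGB array in [0, 1]; scale to uint8
        # and swap RGB -> BGR to match the docstring's return type.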
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output