model_manager.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. from typing import Dict, List
  2. import numpy as np
  3. import torch
  4. from loguru import logger
  5. from sorawm.iopaint.download import scan_models
  6. from sorawm.iopaint.helper import switch_mps_device
  7. from sorawm.iopaint.model import SD, SDXL, ControlNet, models
  8. from sorawm.iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
  9. from sorawm.iopaint.model.brushnet.brushnet_xl_wrapper import BrushNetXLWrapper
  10. from sorawm.iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
  11. from sorawm.iopaint.model.utils import is_local_files_only, torch_gc
  12. from sorawm.iopaint.schema import InpaintRequest, ModelInfo, ModelType
  13. class ModelManager:
  14. def __init__(self, name: str, device: torch.device, **kwargs):
  15. self.name = name
  16. self.device = device
  17. self.kwargs = kwargs
  18. self.available_models: Dict[str, ModelInfo] = {}
  19. self.scan_models()
  20. self.enable_controlnet = kwargs.get("enable_controlnet", False)
  21. controlnet_method = kwargs.get("controlnet_method", None)
  22. if (
  23. controlnet_method is None
  24. and name in self.available_models
  25. and self.available_models[name].support_controlnet
  26. ):
  27. controlnet_method = self.available_models[name].controlnets[0]
  28. self.controlnet_method = controlnet_method
  29. self.enable_brushnet = kwargs.get("enable_brushnet", False)
  30. self.brushnet_method = kwargs.get("brushnet_method", None)
  31. self.enable_powerpaint_v2 = kwargs.get("enable_powerpaint_v2", False)
  32. self.model = self.init_model(name, device, **kwargs)
  33. @property
  34. def current_model(self) -> ModelInfo:
  35. return self.available_models[self.name]
  36. def init_model(self, name: str, device, **kwargs):
  37. logger.info(f"Loading model: {name}")
  38. if name not in self.available_models:
  39. raise NotImplementedError(
  40. f"Unsupported model: {name}. Available models: {list(self.available_models.keys())}"
  41. )
  42. model_info = self.available_models[name]
  43. kwargs = {
  44. **kwargs,
  45. "model_info": model_info,
  46. "enable_controlnet": self.enable_controlnet,
  47. "controlnet_method": self.controlnet_method,
  48. "enable_brushnet": self.enable_brushnet,
  49. "brushnet_method": self.brushnet_method,
  50. }
  51. if model_info.support_controlnet and self.enable_controlnet:
  52. return ControlNet(device, **kwargs)
  53. if model_info.support_brushnet and self.enable_brushnet:
  54. if model_info.model_type == ModelType.DIFFUSERS_SD:
  55. return BrushNetWrapper(device, **kwargs)
  56. elif model_info.model_type == ModelType.DIFFUSERS_SDXL:
  57. return BrushNetXLWrapper(device, **kwargs)
  58. if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
  59. return PowerPaintV2(device, **kwargs)
  60. if model_info.name in models:
  61. return models[name](device, **kwargs)
  62. if model_info.model_type in [
  63. ModelType.DIFFUSERS_SD_INPAINT,
  64. ModelType.DIFFUSERS_SD,
  65. ]:
  66. return SD(device, **kwargs)
  67. if model_info.model_type in [
  68. ModelType.DIFFUSERS_SDXL_INPAINT,
  69. ModelType.DIFFUSERS_SDXL,
  70. ]:
  71. return SDXL(device, **kwargs)
  72. raise NotImplementedError(f"Unsupported model: {name}")
  73. @torch.inference_mode()
  74. def __call__(self, image, mask, config: InpaintRequest):
  75. """
  76. Args:
  77. image: [H, W, C] RGB
  78. mask: [H, W, 1] 255 means area to repaint
  79. config:
  80. Returns:
  81. BGR image
  82. """
  83. if config.enable_controlnet:
  84. self.switch_controlnet_method(config)
  85. if config.enable_brushnet:
  86. self.switch_brushnet_method(config)
  87. self.enable_disable_powerpaint_v2(config)
  88. self.enable_disable_lcm_lora(config)
  89. return self.model(image, mask, config).astype(np.uint8)
  90. def scan_models(self) -> List[ModelInfo]:
  91. available_models = scan_models()
  92. self.available_models = {it.name: it for it in available_models}
  93. return available_models
  94. def switch(self, new_name: str):
  95. if new_name == self.name:
  96. return
  97. old_name = self.name
  98. old_controlnet_method = self.controlnet_method
  99. self.name = new_name
  100. if (
  101. self.available_models[new_name].support_controlnet
  102. and self.controlnet_method
  103. not in self.available_models[new_name].controlnets
  104. ):
  105. self.controlnet_method = self.available_models[new_name].controlnets[0]
  106. try:
  107. # TODO: enable/disable controlnet without reload model
  108. del self.model
  109. torch_gc()
  110. self.model = self.init_model(
  111. new_name, switch_mps_device(new_name, self.device), **self.kwargs
  112. )
  113. except Exception as e:
  114. self.name = old_name
  115. self.controlnet_method = old_controlnet_method
  116. logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
  117. self.model = self.init_model(
  118. old_name, switch_mps_device(old_name, self.device), **self.kwargs
  119. )
  120. raise e
  121. def switch_brushnet_method(self, config):
  122. if not self.available_models[self.name].support_brushnet:
  123. return
  124. if (
  125. self.enable_brushnet
  126. and config.brushnet_method
  127. and self.brushnet_method != config.brushnet_method
  128. ):
  129. old_brushnet_method = self.brushnet_method
  130. self.brushnet_method = config.brushnet_method
  131. self.model.switch_brushnet_method(config.brushnet_method)
  132. logger.info(
  133. f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}"
  134. )
  135. elif self.enable_brushnet != config.enable_brushnet:
  136. self.enable_brushnet = config.enable_brushnet
  137. self.brushnet_method = config.brushnet_method
  138. pipe_components = {
  139. "vae": self.model.model.vae,
  140. "text_encoder": self.model.model.text_encoder,
  141. "unet": self.model.model.unet,
  142. }
  143. if hasattr(self.model.model, "text_encoder_2"):
  144. pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
  145. if hasattr(self.model.model, "tokenizer"):
  146. pipe_components["tokenizer"] = self.model.model.tokenizer
  147. if hasattr(self.model.model, "tokenizer_2"):
  148. pipe_components["tokenizer_2"] = self.model.model.tokenizer_2
  149. self.model = self.init_model(
  150. self.name,
  151. switch_mps_device(self.name, self.device),
  152. pipe_components=pipe_components,
  153. **self.kwargs,
  154. )
  155. if not config.enable_brushnet:
  156. logger.info("BrushNet Disabled")
  157. else:
  158. logger.info("BrushNet Enabled")
  159. def switch_controlnet_method(self, config):
  160. if not self.available_models[self.name].support_controlnet:
  161. return
  162. if (
  163. self.enable_controlnet
  164. and config.controlnet_method
  165. and self.controlnet_method != config.controlnet_method
  166. ):
  167. old_controlnet_method = self.controlnet_method
  168. self.controlnet_method = config.controlnet_method
  169. self.model.switch_controlnet_method(config.controlnet_method)
  170. logger.info(
  171. f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
  172. )
  173. elif self.enable_controlnet != config.enable_controlnet:
  174. self.enable_controlnet = config.enable_controlnet
  175. self.controlnet_method = config.controlnet_method
  176. pipe_components = {
  177. "vae": self.model.model.vae,
  178. "text_encoder": self.model.model.text_encoder,
  179. "unet": self.model.model.unet,
  180. }
  181. if hasattr(self.model.model, "text_encoder_2"):
  182. pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
  183. self.model = self.init_model(
  184. self.name,
  185. switch_mps_device(self.name, self.device),
  186. pipe_components=pipe_components,
  187. **self.kwargs,
  188. )
  189. if not config.enable_controlnet:
  190. logger.info("Disable controlnet")
  191. else:
  192. logger.info(f"Enable controlnet: {config.controlnet_method}")
  193. def enable_disable_powerpaint_v2(self, config: InpaintRequest):
  194. if not self.available_models[self.name].support_powerpaint_v2:
  195. return
  196. if self.enable_powerpaint_v2 != config.enable_powerpaint_v2:
  197. self.enable_powerpaint_v2 = config.enable_powerpaint_v2
  198. pipe_components = {"vae": self.model.model.vae}
  199. self.model = self.init_model(
  200. self.name,
  201. switch_mps_device(self.name, self.device),
  202. pipe_components=pipe_components,
  203. **self.kwargs,
  204. )
  205. if config.enable_powerpaint_v2:
  206. logger.info("Enable PowerPaintV2")
  207. else:
  208. logger.info("Disable PowerPaintV2")
  209. def enable_disable_lcm_lora(self, config: InpaintRequest):
  210. if self.available_models[self.name].support_lcm_lora:
  211. # TODO: change this if load other lora is supported
  212. lcm_lora_loaded = bool(self.model.model.get_list_adapters())
  213. if config.sd_lcm_lora:
  214. if not lcm_lora_loaded:
  215. logger.info("Load LCM LORA")
  216. self.model.model.load_lora_weights(
  217. self.model.lcm_lora_id,
  218. weight_name="pytorch_lora_weights.safetensors",
  219. local_files_only=is_local_files_only(),
  220. )
  221. else:
  222. logger.info("Enable LCM LORA")
  223. self.model.model.enable_lora()
  224. else:
  225. if lcm_lora_loaded:
  226. logger.info("Disable LCM LORA")
  227. self.model.model.disable_lora()