download.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import glob
  2. import json
  3. import os
  4. from functools import lru_cache
  5. from pathlib import Path
  6. from typing import List, Optional
  7. from loguru import logger
  8. from sorawm.iopaint.const import (
  9. ANYTEXT_NAME,
  10. DEFAULT_MODEL_DIR,
  11. DIFFUSERS_SD_CLASS_NAME,
  12. DIFFUSERS_SD_INPAINT_CLASS_NAME,
  13. DIFFUSERS_SDXL_CLASS_NAME,
  14. DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
  15. )
  16. from sorawm.iopaint.model.original_sd_configs import get_config_files
  17. from sorawm.iopaint.schema import ModelInfo, ModelType
  def cli_download_model(model: str):
      """Download ``model`` for the command line interface.

      Registered erase models and the AnyText model are fetched through their
      own ``download()`` method. Any other name is passed to
      ``DiffusionPipeline.download`` (fp16 variant) as a Hugging Face model id.

      Args:
          model: a key of the registered ``models`` mapping, or a model id.
      """
      from sorawm.iopaint.model import models
      from sorawm.iopaint.model.utils import handle_from_pretrained_exceptions

      # Both "built-in" cases used identical code; evaluate them in the same
      # order as before (erase-model check first, then the AnyText name).
      uses_builtin_downloader = (
          model in models and models[model].is_erase_model
      ) or model == ANYTEXT_NAME

      if uses_builtin_downloader:
          logger.info(f"Downloading {model}...")
          models[model].download()
          logger.info("Done.")
          return

      logger.info(f"Downloading model from Huggingface: {model}")
      from diffusers import DiffusionPipeline

      target_path = handle_from_pretrained_exceptions(
          DiffusionPipeline.download,
          pretrained_model_name=model,
          variant="fp16",
      )
      logger.info(f"Done. Downloaded to {target_path}")
  def folder_name_to_show_name(name: str) -> str:
      """Turn a cache folder name such as ``models--org--repo`` into ``org/repo``.

      Every ``models--`` occurrence is dropped, then each remaining ``--``
      separator becomes ``/``.
      """
      without_prefix = name.replace("models--", "")
      return "/".join(without_prefix.split("--"))
  @lru_cache(maxsize=512)
  def get_sd_model_type(model_abs_path: str) -> Optional[ModelType]:
      """Classify a single-file SD checkpoint as inpainting or regular SD.

      A file whose name contains "inpaint" (case-insensitive) is classified
      without being loaded. Otherwise the checkpoint is loaded once as an
      inpainting pipeline with ``num_in_channels=9`` and the "v1" original
      config. Success means an inpainting model; a ``ValueError`` that mentions
      ``[320, 4, 3, 3]`` means a regular SD model.

      Results (including ``None``) are memoized per path, up to 512 entries.

      Args:
          model_abs_path: absolute path of a ``.safetensors`` / ``.ckpt`` file.

      Returns:
          ``ModelType.DIFFUSERS_SD_INPAINT``, ``ModelType.DIFFUSERS_SD``, or
          ``None`` when the file is ignored or fails to load.
      """
      if "inpaint" in Path(model_abs_path).name.lower():
          # Trust the file name; avoids an expensive checkpoint load.
          return ModelType.DIFFUSERS_SD_INPAINT

      # Load once to check num_in_channels.
      from diffusers import StableDiffusionInpaintPipeline

      try:
          StableDiffusionInpaintPipeline.from_single_file(
              model_abs_path,
              load_safety_checker=False,
              num_in_channels=9,
              original_config_file=get_config_files()["v1"],
          )
      except ValueError as e:
          # A shape-mismatch error mentioning [320, 4, 3, 3] (presumably the
          # UNet's first conv having 4 input channels rather than the
          # requested 9) identifies a regular, non-inpainting SD checkpoint.
          if "[320, 4, 3, 3]" in str(e):
              return ModelType.DIFFUSERS_SD
          logger.info(f"Ignore non sd file: {model_abs_path}")
          return None
      except Exception as e:
          logger.error(f"Failed to load {model_abs_path}: {e}")
          return None
      return ModelType.DIFFUSERS_SD_INPAINT
  @lru_cache(maxsize=512)
  def get_sdxl_model_type(model_abs_path: str) -> Optional[ModelType]:
      """Classify a single-file SDXL checkpoint as inpainting or regular SDXL.

      A file whose *name* contains "inpaint" (case-insensitive) is classified
      without being loaded. Only the file name is inspected, so a parent
      directory that happens to contain "inpaint" does not turn every model
      into an inpainting one. Otherwise the checkpoint is loaded once as an
      SDXL inpainting pipeline with ``num_in_channels=9`` and the "xl"
      original config, and the loaded UNet's ``in_channels`` decides.

      Results (including ``None``) are memoized per path, up to 512 entries,
      matching ``get_sd_model_type``.

      Args:
          model_abs_path: absolute path of a ``.safetensors`` / ``.ckpt`` file.

      Returns:
          ``ModelType.DIFFUSERS_SDXL_INPAINT``, ``ModelType.DIFFUSERS_SDXL``, or
          ``None`` when the file is ignored or fails to load.
      """
      if "inpaint" in Path(model_abs_path).name.lower():
          return ModelType.DIFFUSERS_SDXL_INPAINT

      # Load once to check num_in_channels.
      from diffusers import StableDiffusionXLInpaintPipeline

      try:
          model = StableDiffusionXLInpaintPipeline.from_single_file(
              model_abs_path,
              load_safety_checker=False,
              num_in_channels=9,
              original_config_file=get_config_files()["xl"],
          )
          # Loading with num_in_channels=9 can succeed even for a regular
          # model, so the actual UNet config must be checked:
          # https://github.com/huggingface/diffusers/issues/6610
          in_channels = model.unet.config.in_channels
      except ValueError as e:
          if "[320, 4, 3, 3]" in str(e):
              return ModelType.DIFFUSERS_SDXL
          logger.info(f"Ignore non sdxl file: {model_abs_path}")
          return None
      except Exception as e:
          logger.error(f"Failed to load {model_abs_path}: {e}")
          return None

      if in_channels == 9:
          return ModelType.DIFFUSERS_SDXL_INPAINT
      return ModelType.DIFFUSERS_SDXL
  def _load_model_type_cache(cache_file: Path) -> dict:
      """Return the ``{file name: model type}`` cache stored in ``cache_file``.

      A missing, unreadable, or malformed cache yields ``{}``. A broken cache
      then only costs re-detection instead of breaking the scan.
      """
      if not cache_file.exists():
          return {}
      try:
          with open(cache_file, "r", encoding="utf-8") as f:
              data = json.load(f)
      except (OSError, ValueError) as e:
          # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
          logger.warning(f"Ignore unreadable model type cache {cache_file}: {e}")
          return {}
      # Explicit check rather than ``assert`` so it still applies under -O.
      if not isinstance(data, dict):
          logger.warning(f"Ignore malformed model type cache {cache_file}")
          return {}
      return data


  def _save_model_type_cache(cache_file: Path, model_type_cache: dict) -> None:
      """Best-effort write of ``model_type_cache`` to ``cache_file`` as JSON."""
      try:
          with open(cache_file, "w", encoding="utf-8") as fw:
              json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
      except OSError as e:
          # A read-only model folder must not prevent listing its models.
          logger.warning(f"Failed to write model type cache {cache_file}: {e}")


  def _scan_single_file_dir(model_dir: Path, detect_model_type) -> List[ModelInfo]:
      """Collect ``.safetensors`` / ``.ckpt`` files directly inside ``model_dir``.

      ``detect_model_type`` is called with a file's absolute path and returns
      its ``ModelType``, or ``None`` to skip the file. Detected types are kept
      in ``<model_dir>/iopaint_cache.json`` keyed by file name, so later scans
      reuse them instead of calling the detector again.
      """
      cache_file = model_dir / "iopaint_cache.json"
      model_type_cache = _load_model_type_cache(cache_file)

      res = []
      for it in model_dir.glob("*.*"):
          if it.suffix not in (".safetensors", ".ckpt"):
              continue
          model_abs_path = str(it.absolute())
          model_type = model_type_cache.get(it.name)
          if model_type is None:
              model_type = detect_model_type(model_abs_path)
          if model_type is None:
              continue
          model_type_cache[it.name] = model_type
          res.append(
              ModelInfo(
                  name=it.name,
                  path=model_abs_path,
                  model_type=model_type,
                  is_single_file_diffusers=True,
              )
          )

      # Persist once per directory (previously the SDXL cache was rewritten
      # once per file), and never create the directory just for the cache.
      if model_dir.exists():
          _save_model_type_cache(cache_file, model_type_cache)
      return res


  def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
      """List single-file SD and SDXL checkpoints under ``cache_dir``.

      ``<cache_dir>/stable_diffusion`` is classified with ``get_sd_model_type``
      and ``<cache_dir>/stable_diffusion_xl`` with ``get_sdxl_model_type``. SD
      results come first. Missing directories simply contribute nothing.

      Args:
          cache_dir: model root directory (``str`` or ``Path``).

      Returns:
          One ``ModelInfo`` (``is_single_file_diffusers=True``) per checkpoint.
      """
      cache_dir = Path(cache_dir)
      res = _scan_single_file_dir(cache_dir / "stable_diffusion", get_sd_model_type)
      res.extend(
          _scan_single_file_dir(cache_dir / "stable_diffusion_xl", get_sdxl_model_type)
      )
      return res
  def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
      """Describe every registered erase model whose files are already downloaded.

      ``model_dir`` is accepted for interface compatibility but is not read
      here. Each model's own ``is_downloaded()`` decides (presumably it knows
      its own download location; not visible from this file).

      Returns:
          ``ModelInfo`` entries of type ``ModelType.INPAINT``, with the model's
          registry name used as both ``name`` and ``path``.
      """
      from sorawm.iopaint.model import models

      return [
          ModelInfo(name=model_name, path=model_name, model_type=ModelType.INPAINT)
          for model_name, model_cls in models.items()
          if model_cls.is_erase_model and model_cls.is_downloaded()
      ]
  def scan_diffusers_models() -> List[ModelInfo]:
      """List diffusers pipelines found in the Hugging Face hub cache.

      Every ``model_index.json`` below ``HF_HUB_CACHE`` is inspected.
      - The display name comes from the folder three levels above the index
        file (presumably ``models--org--repo/snapshots/<revision>/`` in the
        HF cache layout), via ``folder_name_to_show_name``.
      - The model type comes from the index's ``_class_name``; names
        containing "PowerPaint" are always ``DIFFUSERS_OTHER``.
      - Unreadable or malformed index files and unsupported pipeline classes
        are skipped instead of aborting the scan.
      - Each display name is reported at most once.
      """
      from huggingface_hub.constants import HF_HUB_CACHE

      class_name_to_type = {
          DIFFUSERS_SD_CLASS_NAME: ModelType.DIFFUSERS_SD,
          DIFFUSERS_SD_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SD_INPAINT,
          DIFFUSERS_SDXL_CLASS_NAME: ModelType.DIFFUSERS_SDXL,
          DIFFUSERS_SDXL_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SDXL_INPAINT,
      }
      other_class_names = {
          "StableDiffusionInstructPix2PixPipeline",
          "PaintByExamplePipeline",
          "KandinskyV22InpaintPipeline",
          "AnyText",
      }

      available_models = []
      seen_names = set()
      cache_dir = Path(HF_HUB_CACHE)
      model_index_files = glob.glob(
          os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
      )
      for index_file in map(Path, model_index_files):
          try:
              with open(index_file, "r", encoding="utf-8") as f:
                  data = json.load(f)
          except (OSError, ValueError):
              # Unreadable file or invalid JSON / encoding: skip it.
              continue
          if not isinstance(data, dict):
              continue

          class_name = data.get("_class_name")
          name = folder_name_to_show_name(index_file.parent.parent.parent.name)
          if name in seen_names:
              continue

          if "PowerPaint" in name:
              model_type = ModelType.DIFFUSERS_OTHER
          elif not isinstance(class_name, str):
              # Missing or non-string _class_name previously raised KeyError.
              continue
          elif class_name in class_name_to_type:
              model_type = class_name_to_type[class_name]
          elif class_name in other_class_names:
              model_type = ModelType.DIFFUSERS_OTHER
          else:
              continue

          seen_names.add(name)
          available_models.append(
              ModelInfo(
                  name=name,
                  path=name,
                  model_type=model_type,
              )
          )
      return available_models
  def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
      """List diffusers-format SD/SDXL folders anywhere below ``cache_dir``.

      Each ``model_index.json`` found recursively marks a model folder:
      - The display name is derived from that folder's name via
        ``folder_name_to_show_name``.
      - The path is the folder's absolute path.
      - The type comes from ``_class_name``; only the four SD/SDXL
        (inpaint) pipeline classes are accepted.
      - An index file that cannot be read or is not a JSON object is
        reported with ``logger.error`` and skipped; a missing or unsupported
        ``_class_name`` is skipped silently.
      - Each display name is reported at most once.

      Args:
          cache_dir: directory to search (``str`` or ``Path``).
      """
      cache_dir = Path(cache_dir)
      class_name_to_type = {
          DIFFUSERS_SD_CLASS_NAME: ModelType.DIFFUSERS_SD,
          DIFFUSERS_SD_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SD_INPAINT,
          DIFFUSERS_SDXL_CLASS_NAME: ModelType.DIFFUSERS_SDXL,
          DIFFUSERS_SDXL_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SDXL_INPAINT,
      }

      available_models = []
      seen_names = set()
      model_index_files = glob.glob(
          os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True
      )
      for index_file in map(Path, model_index_files):
          # open() is inside the try so permission errors are reported per
          # file instead of aborting the whole scan.
          try:
              with open(index_file, "r", encoding="utf-8") as f:
                  data = json.load(f)
              if not isinstance(data, dict):
                  raise ValueError("model_index.json is not a JSON object")
          except (OSError, ValueError):
              logger.error(
                  f"Failed to load {index_file}, please try revert from original model or fix model_index.json by hand."
              )
              continue

          name = folder_name_to_show_name(index_file.parent.name)
          if name in seen_names:
              continue

          class_name = data.get("_class_name")
          if not isinstance(class_name, str):
              continue
          model_type = class_name_to_type.get(class_name)
          if model_type is None:
              continue

          seen_names.add(name)
          available_models.append(
              ModelInfo(
                  name=name,
                  path=str(index_file.parent.absolute()),
                  model_type=model_type,
              )
          )
      return available_models
  def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
      """Gather converted diffusers models from the SD and SDXL model folders.

      ``<cache_dir>/stable_diffusion`` is searched first, then
      ``<cache_dir>/stable_diffusion_xl``; results are concatenated in that
      order.
      """
      root = Path(cache_dir)
      found: List[ModelInfo] = []
      for sub_dir_name in ("stable_diffusion", "stable_diffusion_xl"):
          found.extend(_scan_converted_diffusers_models(root / sub_dir_name))
      return found
  def scan_models() -> List[ModelInfo]:
      """Return every locally available model from all supported locations.

      The model directory is ``$XDG_CACHE_HOME`` when it is set to a non-empty
      value, otherwise ``DEFAULT_MODEL_DIR``. Results are concatenated in this
      order:
      1. downloaded erase models;
      2. single-file SD/SDXL checkpoints;
      3. diffusers models in the Hugging Face hub cache;
      4. converted diffusers folders under the model directory.
      """
      # The XDG Base Directory spec says an *empty* XDG_CACHE_HOME must be
      # treated as unset. os.getenv(name, default) would return "" and make
      # the scans below run relative to the current working directory.
      model_dir = os.getenv("XDG_CACHE_HOME") or DEFAULT_MODEL_DIR
      available_models: List[ModelInfo] = []
      available_models.extend(scan_inpaint_models(model_dir))
      available_models.extend(scan_single_file_diffusion_models(model_dir))
      available_models.extend(scan_diffusers_models())
      available_models.extend(scan_converted_diffusers_models(model_dir))
      return available_models