# cli.py — Typer command-line entry points for IOPaint.
import os
import sys
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import typer
from fastapi import FastAPI
from loguru import logger
from typer import Option
from typer_config import use_json_config

from sorawm.iopaint.const import *
from sorawm.iopaint.runtime import check_device, dump_environment_info, setup_model_dir
from sorawm.iopaint.schema import (
    Device,
    InteractiveSegModel,
    RealESRGANModel,
    RemoveBGModel,
)
  18. typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
  19. @typer_app.command(help="Install all plugins dependencies")
  20. def install_plugins_packages():
  21. from sorawm.iopaint.installer import install_plugins_package
  22. install_plugins_package()
  23. @typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
  24. def download(
  25. model: str = Option(
  26. ..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
  27. ),
  28. model_dir: Path = Option(
  29. DEFAULT_MODEL_DIR,
  30. help=MODEL_DIR_HELP,
  31. file_okay=False,
  32. callback=setup_model_dir,
  33. ),
  34. ):
  35. from sorawm.iopaint.download import cli_download_model
  36. cli_download_model(model)
  37. @typer_app.command(name="list", help="List downloaded models")
  38. def list_model(
  39. model_dir: Path = Option(
  40. DEFAULT_MODEL_DIR,
  41. help=MODEL_DIR_HELP,
  42. file_okay=False,
  43. callback=setup_model_dir,
  44. ),
  45. ):
  46. from sorawm.iopaint.download import scan_models
  47. scanned_models = scan_models()
  48. for it in scanned_models:
  49. print(it.name)
  50. @typer_app.command(help="Batch processing images")
  51. def run(
  52. model: str = Option("lama"),
  53. device: Device = Option(Device.cpu),
  54. image: Path = Option(..., help="Image folders or file path"),
  55. mask: Path = Option(
  56. ...,
  57. help="Mask folders or file path. "
  58. "If it is a directory, the mask images in the directory should have the same name as the original image."
  59. "If it is a file, all images will use this mask."
  60. "Mask will automatically resize to the same size as the original image.",
  61. ),
  62. output: Path = Option(..., help="Output directory or file path"),
  63. config: Path = Option(
  64. None, help="Config file path. You can use dump command to create a base config."
  65. ),
  66. concat: bool = Option(
  67. False, help="Concat original image, mask and output images into one image"
  68. ),
  69. model_dir: Path = Option(
  70. DEFAULT_MODEL_DIR,
  71. help=MODEL_DIR_HELP,
  72. file_okay=False,
  73. callback=setup_model_dir,
  74. ),
  75. ):
  76. from sorawm.iopaint.download import cli_download_model, scan_models
  77. scanned_models = scan_models()
  78. if model not in [it.name for it in scanned_models]:
  79. logger.info(f"{model} not found in {model_dir}, try to downloading")
  80. cli_download_model(model)
  81. from sorawm.iopaint.batch_processing import batch_inpaint
  82. batch_inpaint(model, device, image, mask, output, config, concat)
  83. @typer_app.command(help="Start IOPaint server")
  84. @use_json_config()
  85. def start(
  86. host: str = Option("127.0.0.1"),
  87. port: int = Option(8080),
  88. inbrowser: bool = Option(False, help=INBROWSER_HELP),
  89. model: str = Option(
  90. DEFAULT_MODEL,
  91. help=f"Erase models: [{', '.join(AVAILABLE_MODELS)}].\n"
  92. f"Diffusion models: [{', '.join(DIFFUSION_MODELS)}] or any SD/SDXL normal/inpainting models on HuggingFace.",
  93. ),
  94. model_dir: Path = Option(
  95. DEFAULT_MODEL_DIR,
  96. help=MODEL_DIR_HELP,
  97. dir_okay=True,
  98. file_okay=False,
  99. callback=setup_model_dir,
  100. ),
  101. low_mem: bool = Option(False, help=LOW_MEM_HELP),
  102. no_half: bool = Option(False, help=NO_HALF_HELP),
  103. cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
  104. disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
  105. cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
  106. local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
  107. device: Device = Option(Device.cpu),
  108. input: Optional[Path] = Option(None, help=INPUT_HELP),
  109. mask_dir: Optional[Path] = Option(
  110. None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
  111. ),
  112. output_dir: Optional[Path] = Option(
  113. None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
  114. ),
  115. quality: int = Option(100, help=QUALITY_HELP),
  116. enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
  117. interactive_seg_model: InteractiveSegModel = Option(
  118. InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
  119. ),
  120. interactive_seg_device: Device = Option(Device.cpu),
  121. enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
  122. remove_bg_device: Device = Option(Device.cpu, help=REMOVE_BG_DEVICE_HELP),
  123. remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4),
  124. enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
  125. enable_realesrgan: bool = Option(False),
  126. realesrgan_device: Device = Option(Device.cpu),
  127. realesrgan_model: RealESRGANModel = Option(RealESRGANModel.realesr_general_x4v3),
  128. enable_gfpgan: bool = Option(False),
  129. gfpgan_device: Device = Option(Device.cpu),
  130. enable_restoreformer: bool = Option(False),
  131. restoreformer_device: Device = Option(Device.cpu),
  132. ):
  133. dump_environment_info()
  134. device = check_device(device)
  135. remove_bg_device = check_device(remove_bg_device)
  136. realesrgan_device = check_device(realesrgan_device)
  137. gfpgan_device = check_device(gfpgan_device)
  138. if input and not input.exists():
  139. logger.error(f"invalid --input: {input} not exists")
  140. exit(-1)
  141. if mask_dir and not mask_dir.exists():
  142. logger.error(f"invalid --mask-dir: {mask_dir} not exists")
  143. exit(-1)
  144. if input and input.is_dir() and not output_dir:
  145. logger.error(
  146. "invalid --output-dir: --output-dir must be set when --input is a directory"
  147. )
  148. exit(-1)
  149. if output_dir:
  150. output_dir = output_dir.expanduser().absolute()
  151. logger.info(f"Image will be saved to {output_dir}")
  152. if not output_dir.exists():
  153. logger.info(f"Create output directory {output_dir}")
  154. output_dir.mkdir(parents=True)
  155. if mask_dir:
  156. mask_dir = mask_dir.expanduser().absolute()
  157. model_dir = model_dir.expanduser().absolute()
  158. if local_files_only:
  159. os.environ["TRANSFORMERS_OFFLINE"] = "1"
  160. os.environ["HF_HUB_OFFLINE"] = "1"
  161. from sorawm.iopaint.download import cli_download_model, scan_models
  162. scanned_models = scan_models()
  163. if model not in [it.name for it in scanned_models]:
  164. logger.info(f"{model} not found in {model_dir}, try to downloading")
  165. cli_download_model(model)
  166. from sorawm.iopaint.api import Api
  167. from sorawm.iopaint.schema import ApiConfig
  168. @asynccontextmanager
  169. async def lifespan(app: FastAPI):
  170. if inbrowser:
  171. webbrowser.open(f"http://localhost:{port}", new=0, autoraise=True)
  172. yield
  173. app = FastAPI(lifespan=lifespan)
  174. api_config = ApiConfig(
  175. host=host,
  176. port=port,
  177. inbrowser=inbrowser,
  178. model=model,
  179. no_half=no_half,
  180. low_mem=low_mem,
  181. cpu_offload=cpu_offload,
  182. disable_nsfw_checker=disable_nsfw_checker,
  183. local_files_only=local_files_only,
  184. cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
  185. device=device,
  186. input=input,
  187. mask_dir=mask_dir,
  188. output_dir=output_dir,
  189. quality=quality,
  190. enable_interactive_seg=enable_interactive_seg,
  191. interactive_seg_model=interactive_seg_model,
  192. interactive_seg_device=interactive_seg_device,
  193. enable_remove_bg=enable_remove_bg,
  194. remove_bg_device=remove_bg_device,
  195. remove_bg_model=remove_bg_model,
  196. enable_anime_seg=enable_anime_seg,
  197. enable_realesrgan=enable_realesrgan,
  198. realesrgan_device=realesrgan_device,
  199. realesrgan_model=realesrgan_model,
  200. enable_gfpgan=enable_gfpgan,
  201. gfpgan_device=gfpgan_device,
  202. enable_restoreformer=enable_restoreformer,
  203. restoreformer_device=restoreformer_device,
  204. )
  205. print(api_config.model_dump_json(indent=4))
  206. api = Api(app, api_config)
  207. api.launch()
  208. @typer_app.command(help="Start IOPaint web config page")
  209. def start_web_config(
  210. config_file: Path = Option("config.json"),
  211. ):
  212. dump_environment_info()
  213. from sorawm.iopaint.web_config import main
  214. main(config_file)