web_config.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. import json
  2. import mimetypes
  3. import os
  4. from pathlib import Path
  5. # fix for windows mimetypes registry entries being borked
  6. # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
  7. mimetypes.add_type("application/javascript", ".js")
  8. mimetypes.add_type("text/css", ".css")
  9. from sorawm.iopaint.schema import (
  10. ApiConfig,
  11. Device,
  12. InteractiveSegModel,
  13. RealESRGANModel,
  14. RemoveBGModel,
  15. )
  16. os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
  17. from datetime import datetime
  18. from json import JSONDecodeError
  19. import gradio as gr
  20. from loguru import logger
  21. from sorawm.iopaint.const import *
  22. from sorawm.iopaint.download import scan_models
  23. _config_file: Path = None
  24. default_configs = dict(
  25. host="127.0.0.1",
  26. port=8080,
  27. inbrowser=True,
  28. model=DEFAULT_MODEL,
  29. model_dir=DEFAULT_MODEL_DIR,
  30. no_half=False,
  31. low_mem=False,
  32. cpu_offload=False,
  33. disable_nsfw_checker=False,
  34. local_files_only=False,
  35. cpu_textencoder=False,
  36. device=Device.cuda,
  37. input=None,
  38. mask_dir=None,
  39. output_dir=None,
  40. quality=95,
  41. enable_interactive_seg=False,
  42. interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
  43. interactive_seg_device=Device.cpu,
  44. enable_remove_bg=False,
  45. remove_bg_device=Device.cpu,
  46. remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
  47. enable_anime_seg=False,
  48. enable_realesrgan=False,
  49. realesrgan_device=Device.cpu,
  50. realesrgan_model=RealESRGANModel.realesr_general_x4v3,
  51. enable_gfpgan=False,
  52. gfpgan_device=Device.cpu,
  53. enable_restoreformer=False,
  54. restoreformer_device=Device.cpu,
  55. )
# Configuration model edited by this web UI: ApiConfig plus a `model_dir`
# field. Instances are built from keyword arguments in load_config() and
# save_config(), and serialized with `model_dump_json()` when saved.
class WebConfig(ApiConfig):
    # Model directory; defaults to DEFAULT_MODEL_DIR.
    # NOTE(review): whether ApiConfig already declares `model_dir` is not
    # visible from this file -- confirm if this is an override or a new field.
    model_dir: str = DEFAULT_MODEL_DIR
  58. def load_config(p: Path) -> WebConfig:
  59. if p.exists():
  60. with open(p, "r", encoding="utf-8") as f:
  61. try:
  62. return WebConfig(**{**default_configs, **json.load(f)})
  63. except JSONDecodeError:
  64. print("Load config file failed, using default configs")
  65. return WebConfig(**default_configs)
  66. else:
  67. return WebConfig(**default_configs)
  68. def save_config(
  69. host,
  70. port,
  71. model,
  72. model_dir,
  73. no_half,
  74. low_mem,
  75. cpu_offload,
  76. disable_nsfw_checker,
  77. local_files_only,
  78. cpu_textencoder,
  79. device,
  80. input,
  81. mask_dir,
  82. output_dir,
  83. quality,
  84. enable_interactive_seg,
  85. interactive_seg_model,
  86. interactive_seg_device,
  87. enable_remove_bg,
  88. remove_bg_device,
  89. remove_bg_model,
  90. enable_anime_seg,
  91. enable_realesrgan,
  92. realesrgan_device,
  93. realesrgan_model,
  94. enable_gfpgan,
  95. gfpgan_device,
  96. enable_restoreformer,
  97. restoreformer_device,
  98. inbrowser,
  99. ):
  100. config = WebConfig(**locals())
  101. if str(config.input) == ".":
  102. config.input = None
  103. if str(config.output_dir) == ".":
  104. config.output_dir = None
  105. if str(config.mask_dir) == ".":
  106. config.mask_dir = None
  107. config.model = config.model.strip()
  108. print(config.model_dump_json(indent=4))
  109. if config.input and not os.path.exists(config.input):
  110. return "[Error] Input file or directory does not exist"
  111. current_time = datetime.now().strftime("%H:%M:%S")
  112. msg = f"[{current_time}] Successful save config to: {str(_config_file.absolute())}"
  113. logger.info(msg)
  114. try:
  115. with open(_config_file, "w", encoding="utf-8") as f:
  116. f.write(config.model_dump_json(indent=4))
  117. except Exception as e:
  118. return f"Save configure file failed: {str(e)}"
  119. return msg
  120. def change_current_model(new_model):
  121. return new_model
  122. def main(config_file: Path):
  123. global _config_file
  124. _config_file = config_file
  125. init_config = load_config(config_file)
  126. downloaded_models = [it.name for it in scan_models()]
  127. with gr.Blocks() as demo:
  128. with gr.Row():
  129. with gr.Column():
  130. gr.Textbox(config_file, label="Config file", interactive=False)
  131. with gr.Column():
  132. save_btn = gr.Button(value="Save configurations")
  133. message = gr.HTML()
  134. with gr.Tabs():
  135. with gr.Tab("Common"):
  136. with gr.Row():
  137. host = gr.Textbox(init_config.host, label="Host")
  138. port = gr.Number(init_config.port, label="Port", precision=0)
  139. inbrowser = gr.Checkbox(init_config.inbrowser, label=INBROWSER_HELP)
  140. with gr.Row():
  141. recommend_model = gr.Dropdown(
  142. ["lama", "mat", "migan"] + DIFFUSION_MODELS,
  143. label="Recommended Models",
  144. )
  145. downloaded_model = gr.Dropdown(
  146. downloaded_models, label="Downloaded Models"
  147. )
  148. with gr.Column():
  149. model = gr.Textbox(
  150. init_config.model,
  151. label="Current Model. Model will be automatically downloaded. "
  152. "You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
  153. )
  154. device = gr.Radio(
  155. Device.values(), label="Device", value=init_config.device
  156. )
  157. quality = gr.Slider(
  158. value=95,
  159. label=f"Image Quality ({QUALITY_HELP})",
  160. minimum=75,
  161. maximum=100,
  162. step=1,
  163. )
  164. no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
  165. cpu_offload = gr.Checkbox(
  166. init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
  167. )
  168. low_mem = gr.Checkbox(init_config.low_mem, label=f"{LOW_MEM_HELP}")
  169. cpu_textencoder = gr.Checkbox(
  170. init_config.cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
  171. )
  172. disable_nsfw_checker = gr.Checkbox(
  173. init_config.disable_nsfw_checker, label=f"{DISABLE_NSFW_HELP}"
  174. )
  175. local_files_only = gr.Checkbox(
  176. init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
  177. )
  178. with gr.Column():
  179. model_dir = gr.Textbox(
  180. init_config.model_dir, label=f"{MODEL_DIR_HELP}"
  181. )
  182. input = gr.Textbox(
  183. init_config.input,
  184. label=f"Input file or directory. {INPUT_HELP}",
  185. )
  186. output_dir = gr.Textbox(
  187. init_config.output_dir,
  188. label=f"Output directory. {OUTPUT_DIR_HELP}",
  189. )
  190. mask_dir = gr.Textbox(
  191. init_config.mask_dir,
  192. label=f"Mask directory. {MASK_DIR_HELP}",
  193. )
  194. with gr.Tab("Plugins"):
  195. with gr.Row():
  196. enable_interactive_seg = gr.Checkbox(
  197. init_config.enable_interactive_seg, label=INTERACTIVE_SEG_HELP
  198. )
  199. interactive_seg_model = gr.Radio(
  200. InteractiveSegModel.values(),
  201. label=f"Segment Anything models. {INTERACTIVE_SEG_MODEL_HELP}",
  202. value=init_config.interactive_seg_model,
  203. )
  204. interactive_seg_device = gr.Radio(
  205. Device.values(),
  206. label="Segment Anything Device",
  207. value=init_config.interactive_seg_device,
  208. )
  209. with gr.Row():
  210. enable_remove_bg = gr.Checkbox(
  211. init_config.enable_remove_bg, label=REMOVE_BG_HELP
  212. )
  213. remove_bg_device = gr.Radio(
  214. Device.values(),
  215. label=REMOVE_BG_DEVICE_HELP,
  216. value=init_config.remove_bg_device,
  217. )
  218. remove_bg_model = gr.Radio(
  219. RemoveBGModel.values(),
  220. label="Remove bg model",
  221. value=init_config.remove_bg_model,
  222. )
  223. with gr.Row():
  224. enable_anime_seg = gr.Checkbox(
  225. init_config.enable_anime_seg, label=ANIMESEG_HELP
  226. )
  227. with gr.Row():
  228. enable_realesrgan = gr.Checkbox(
  229. init_config.enable_realesrgan, label=REALESRGAN_HELP
  230. )
  231. realesrgan_device = gr.Radio(
  232. Device.values(),
  233. label="RealESRGAN Device",
  234. value=init_config.realesrgan_device,
  235. )
  236. realesrgan_model = gr.Radio(
  237. RealESRGANModel.values(),
  238. label="RealESRGAN model",
  239. value=init_config.realesrgan_model,
  240. )
  241. with gr.Row():
  242. enable_gfpgan = gr.Checkbox(
  243. init_config.enable_gfpgan, label=GFPGAN_HELP
  244. )
  245. gfpgan_device = gr.Radio(
  246. Device.values(),
  247. label="GFPGAN Device",
  248. value=init_config.gfpgan_device,
  249. )
  250. with gr.Row():
  251. enable_restoreformer = gr.Checkbox(
  252. init_config.enable_restoreformer, label=RESTOREFORMER_HELP
  253. )
  254. restoreformer_device = gr.Radio(
  255. Device.values(),
  256. label="RestoreFormer Device",
  257. value=init_config.restoreformer_device,
  258. )
  259. downloaded_model.change(change_current_model, [downloaded_model], model)
  260. recommend_model.change(change_current_model, [recommend_model], model)
  261. save_btn.click(
  262. save_config,
  263. [
  264. host,
  265. port,
  266. model,
  267. model_dir,
  268. no_half,
  269. low_mem,
  270. cpu_offload,
  271. disable_nsfw_checker,
  272. local_files_only,
  273. cpu_textencoder,
  274. device,
  275. input,
  276. mask_dir,
  277. output_dir,
  278. quality,
  279. enable_interactive_seg,
  280. interactive_seg_model,
  281. interactive_seg_device,
  282. enable_remove_bg,
  283. remove_bg_device,
  284. remove_bg_model,
  285. enable_anime_seg,
  286. enable_realesrgan,
  287. realesrgan_device,
  288. realesrgan_model,
  289. enable_gfpgan,
  290. gfpgan_device,
  291. enable_restoreformer,
  292. restoreformer_device,
  293. inbrowser,
  294. ],
  295. message,
  296. )
  297. demo.launch(inbrowser=True, show_api=False)