api.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. import asyncio
  2. import os
  3. import threading
  4. import time
  5. import traceback
  6. from pathlib import Path
  7. from typing import Dict, List, Optional
  8. import cv2
  9. import numpy as np
  10. import socketio
  11. import torch
  12. try:
  13. torch._C._jit_override_can_fuse_on_cpu(False)
  14. torch._C._jit_override_can_fuse_on_gpu(False)
  15. torch._C._jit_set_texpr_fuser_enabled(False)
  16. torch._C._jit_set_nvfuser_enabled(False)
  17. torch._C._jit_set_profiling_mode(False)
  18. except:
  19. pass
  20. import uvicorn
  21. from fastapi import APIRouter, FastAPI, Request, UploadFile
  22. from fastapi.encoders import jsonable_encoder
  23. from fastapi.exceptions import HTTPException
  24. from fastapi.middleware.cors import CORSMiddleware
  25. from fastapi.responses import FileResponse, JSONResponse, Response
  26. from fastapi.staticfiles import StaticFiles
  27. from loguru import logger
  28. from PIL import Image
  29. from socketio import AsyncServer
  30. from sorawm.iopaint.file_manager import FileManager
  31. from sorawm.iopaint.helper import (
  32. adjust_mask,
  33. concat_alpha_channel,
  34. decode_base64_to_image,
  35. gen_frontend_mask,
  36. load_img,
  37. numpy_to_bytes,
  38. pil_to_bytes,
  39. )
  40. from sorawm.iopaint.model.utils import torch_gc
  41. from sorawm.iopaint.model_manager import ModelManager
  42. from sorawm.iopaint.plugins import InteractiveSeg, RealESRGANUpscaler, build_plugins
  43. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  44. from sorawm.iopaint.plugins.remove_bg import RemoveBG
  45. from sorawm.iopaint.schema import (
  46. AdjustMaskRequest,
  47. ApiConfig,
  48. GenInfoResponse,
  49. InpaintRequest,
  50. InteractiveSegModel,
  51. ModelInfo,
  52. PluginInfo,
  53. RealESRGANModel,
  54. RemoveBGModel,
  55. RunPluginRequest,
  56. SDSampler,
  57. ServerConfigResponse,
  58. SwitchModelRequest,
  59. SwitchPluginModelRequest,
  60. )
  61. CURRENT_DIR = Path(__file__).parent.absolute().resolve()
  62. WEB_APP_DIR = CURRENT_DIR / "web_app"
  63. def api_middleware(app: FastAPI):
  64. rich_available = False
  65. try:
  66. if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
  67. import anyio # importing just so it can be placed on silent list
  68. import starlette # importing just so it can be placed on silent list
  69. from rich.console import Console
  70. console = Console()
  71. rich_available = True
  72. except Exception:
  73. pass
  74. def handle_exception(request: Request, e: Exception):
  75. err = {
  76. "error": type(e).__name__,
  77. "detail": vars(e).get("detail", ""),
  78. "body": vars(e).get("body", ""),
  79. "errors": str(e),
  80. }
  81. if not isinstance(
  82. e, HTTPException
  83. ): # do not print backtrace on known httpexceptions
  84. message = f"API error: {request.method}: {request.url} {err}"
  85. if rich_available:
  86. print(message)
  87. console.print_exception(
  88. show_locals=True,
  89. max_frames=2,
  90. extra_lines=1,
  91. suppress=[anyio, starlette],
  92. word_wrap=False,
  93. width=min([console.width, 200]),
  94. )
  95. else:
  96. traceback.print_exc()
  97. return JSONResponse(
  98. status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
  99. )
  100. @app.middleware("http")
  101. async def exception_handling(request: Request, call_next):
  102. try:
  103. return await call_next(request)
  104. except Exception as e:
  105. return handle_exception(request, e)
  106. @app.exception_handler(Exception)
  107. async def fastapi_exception_handler(request: Request, e: Exception):
  108. return handle_exception(request, e)
  109. @app.exception_handler(HTTPException)
  110. async def http_exception_handler(request: Request, e: HTTPException):
  111. return handle_exception(request, e)
  112. cors_options = {
  113. "allow_methods": ["*"],
  114. "allow_headers": ["*"],
  115. "allow_origins": ["*"],
  116. "allow_credentials": True,
  117. "expose_headers": ["X-Seed"],
  118. }
  119. app.add_middleware(CORSMiddleware, **cors_options)
  120. global_sio: AsyncServer = None
  121. def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict = {}):
  122. # self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
  123. # logger.info(f"diffusion callback: step={step}, timestep={timestep}")
  124. # We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
  125. # but for now let's just start a separate event loop. It shouldn't make a difference for single person use
  126. asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
  127. return {}
  128. class Api:
  129. def __init__(self, app: FastAPI, config: ApiConfig):
  130. self.app = app
  131. self.config = config
  132. self.router = APIRouter()
  133. self.queue_lock = threading.Lock()
  134. api_middleware(self.app)
  135. self.file_manager = self._build_file_manager()
  136. self.plugins = self._build_plugins()
  137. self.model_manager = self._build_model_manager()
  138. # fmt: off
  139. self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
  140. self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
  141. response_model=ServerConfigResponse)
  142. self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
  143. self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
  144. self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
  145. self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
  146. self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
  147. self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
  148. self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
  149. self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
  150. self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
  151. self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
  152. self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
  153. # fmt: on
  154. global global_sio
  155. self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
  156. self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
  157. self.app.mount("/ws", self.combined_asgi_app)
  158. global_sio = self.sio
  159. def add_api_route(self, path: str, endpoint, **kwargs):
  160. return self.app.add_api_route(path, endpoint, **kwargs)
  161. def api_save_image(self, file: UploadFile):
  162. # Sanitize filename to prevent path traversal
  163. safe_filename = Path(file.filename).name # Get just the filename component
  164. # Construct the full path within output_dir
  165. output_path = self.config.output_dir / safe_filename
  166. # Ensure output directory exists
  167. if not self.config.output_dir or not self.config.output_dir.exists():
  168. raise HTTPException(
  169. status_code=400,
  170. detail="Output directory not configured or doesn't exist",
  171. )
  172. # Read and write the file
  173. origin_image_bytes = file.file.read()
  174. with open(output_path, "wb") as fw:
  175. fw.write(origin_image_bytes)
  176. def api_current_model(self) -> ModelInfo:
  177. return self.model_manager.current_model
  178. def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
  179. if req.name == self.model_manager.name:
  180. return self.model_manager.current_model
  181. self.model_manager.switch(req.name)
  182. return self.model_manager.current_model
  183. def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
  184. if req.plugin_name in self.plugins:
  185. self.plugins[req.plugin_name].switch_model(req.model_name)
  186. if req.plugin_name == RemoveBG.name:
  187. self.config.remove_bg_model = req.model_name
  188. if req.plugin_name == RealESRGANUpscaler.name:
  189. self.config.realesrgan_model = req.model_name
  190. if req.plugin_name == InteractiveSeg.name:
  191. self.config.interactive_seg_model = req.model_name
  192. torch_gc()
  193. def api_server_config(self) -> ServerConfigResponse:
  194. plugins = []
  195. for it in self.plugins.values():
  196. plugins.append(
  197. PluginInfo(
  198. name=it.name,
  199. support_gen_image=it.support_gen_image,
  200. support_gen_mask=it.support_gen_mask,
  201. )
  202. )
  203. return ServerConfigResponse(
  204. plugins=plugins,
  205. modelInfos=self.model_manager.scan_models(),
  206. removeBGModel=self.config.remove_bg_model,
  207. removeBGModels=RemoveBGModel.values(),
  208. realesrganModel=self.config.realesrgan_model,
  209. realesrganModels=RealESRGANModel.values(),
  210. interactiveSegModel=self.config.interactive_seg_model,
  211. interactiveSegModels=InteractiveSegModel.values(),
  212. enableFileManager=self.file_manager is not None,
  213. enableAutoSaving=self.config.output_dir is not None,
  214. enableControlnet=self.model_manager.enable_controlnet,
  215. controlnetMethod=self.model_manager.controlnet_method,
  216. disableModelSwitch=False,
  217. isDesktop=False,
  218. samplers=self.api_samplers(),
  219. )
  220. def api_input_image(self) -> FileResponse:
  221. if self.config.input is None:
  222. raise HTTPException(status_code=200, detail="No input image configured")
  223. if self.config.input.is_file():
  224. return FileResponse(self.config.input)
  225. raise HTTPException(status_code=404, detail="Input image not found")
  226. def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
  227. _, _, info = load_img(file.file.read(), return_info=True)
  228. parts = info.get("parameters", "").split("Negative prompt: ")
  229. prompt = parts[0].strip()
  230. negative_prompt = ""
  231. if len(parts) > 1:
  232. negative_prompt = parts[1].split("\n")[0].strip()
  233. return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
  234. def api_inpaint(self, req: InpaintRequest):
  235. image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
  236. mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
  237. logger.info(f"image ext: {ext}")
  238. mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
  239. if image.shape[:2] != mask.shape[:2]:
  240. raise HTTPException(
  241. 400,
  242. detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
  243. )
  244. start = time.time()
  245. rgb_np_img = self.model_manager(image, mask, req)
  246. logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
  247. torch_gc()
  248. rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
  249. rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
  250. res_img_bytes = pil_to_bytes(
  251. Image.fromarray(rgb_res),
  252. ext=ext,
  253. quality=self.config.quality,
  254. infos=infos,
  255. )
  256. asyncio.run(self.sio.emit("diffusion_finish"))
  257. return Response(
  258. content=res_img_bytes,
  259. media_type=f"image/{ext}",
  260. headers={"X-Seed": str(req.sd_seed)},
  261. )
  262. def api_run_plugin_gen_image(self, req: RunPluginRequest):
  263. ext = "png"
  264. if req.name not in self.plugins:
  265. raise HTTPException(status_code=422, detail="Plugin not found")
  266. if not self.plugins[req.name].support_gen_image:
  267. raise HTTPException(
  268. status_code=422, detail="Plugin does not support output image"
  269. )
  270. rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
  271. bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
  272. torch_gc()
  273. if bgr_or_rgba_np_img.shape[2] == 4:
  274. rgba_np_img = bgr_or_rgba_np_img
  275. else:
  276. rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
  277. rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
  278. return Response(
  279. content=pil_to_bytes(
  280. Image.fromarray(rgba_np_img),
  281. ext=ext,
  282. quality=self.config.quality,
  283. infos=infos,
  284. ),
  285. media_type=f"image/{ext}",
  286. )
  287. def api_run_plugin_gen_mask(self, req: RunPluginRequest):
  288. if req.name not in self.plugins:
  289. raise HTTPException(status_code=422, detail="Plugin not found")
  290. if not self.plugins[req.name].support_gen_mask:
  291. raise HTTPException(
  292. status_code=422, detail="Plugin does not support output image"
  293. )
  294. rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
  295. bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
  296. torch_gc()
  297. res_mask = gen_frontend_mask(bgr_or_gray_mask)
  298. return Response(
  299. content=numpy_to_bytes(res_mask, "png"),
  300. media_type="image/png",
  301. )
  302. def api_samplers(self) -> List[str]:
  303. return [member.value for member in SDSampler.__members__.values()]
  304. def api_adjust_mask(self, req: AdjustMaskRequest):
  305. mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
  306. mask = adjust_mask(mask, req.kernel_size, req.operate)
  307. return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
  308. def launch(self):
  309. self.app.include_router(self.router)
  310. uvicorn.run(
  311. self.combined_asgi_app,
  312. host=self.config.host,
  313. port=self.config.port,
  314. timeout_keep_alive=999999999,
  315. )
  316. def _build_file_manager(self) -> Optional[FileManager]:
  317. if self.config.input and self.config.input.is_dir():
  318. logger.info(
  319. f"Input is directory, initialize file manager {self.config.input}"
  320. )
  321. return FileManager(
  322. app=self.app,
  323. input_dir=self.config.input,
  324. mask_dir=self.config.mask_dir,
  325. output_dir=self.config.output_dir,
  326. )
  327. return None
  328. def _build_plugins(self) -> Dict[str, BasePlugin]:
  329. return build_plugins(
  330. self.config.enable_interactive_seg,
  331. self.config.interactive_seg_model,
  332. self.config.interactive_seg_device,
  333. self.config.enable_remove_bg,
  334. self.config.remove_bg_device,
  335. self.config.remove_bg_model,
  336. self.config.enable_anime_seg,
  337. self.config.enable_realesrgan,
  338. self.config.realesrgan_device,
  339. self.config.realesrgan_model,
  340. self.config.enable_gfpgan,
  341. self.config.gfpgan_device,
  342. self.config.enable_restoreformer,
  343. self.config.restoreformer_device,
  344. self.config.no_half,
  345. )
  346. def _build_model_manager(self):
  347. return ModelManager(
  348. name=self.config.model,
  349. device=torch.device(self.config.device),
  350. no_half=self.config.no_half,
  351. low_mem=self.config.low_mem,
  352. disable_nsfw=self.config.disable_nsfw_checker,
  353. sd_cpu_textencoder=self.config.cpu_textencoder,
  354. local_files_only=self.config.local_files_only,
  355. cpu_offload=self.config.cpu_offload,
  356. callback=diffuser_callback,
  357. )