schema.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. import random
  2. from enum import Enum
  3. from pathlib import Path
  4. from typing import List, Literal, Optional
  5. from loguru import logger
  6. from pydantic import BaseModel, Field, computed_field, model_validator
  7. from sorawm.iopaint.const import (
  8. ANYTEXT_NAME,
  9. INSTRUCT_PIX2PIX_NAME,
  10. KANDINSKY22_NAME,
  11. POWERPAINT_NAME,
  12. SD2_CONTROLNET_CHOICES,
  13. SD_BRUSHNET_CHOICES,
  14. SD_CONTROLNET_CHOICES,
  15. SDXL_BRUSHNET_CHOICES,
  16. SDXL_CONTROLNET_CHOICES,
  17. )
class ModelType(str, Enum):
    """Broad category of a loadable model; drives the capability flags in ModelInfo."""

    INPAINT = "inpaint"  # LaMa, MAT...
    DIFFUSERS_SD = "diffusers_sd"
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
    DIFFUSERS_SDXL = "diffusers_sdxl"
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
    DIFFUSERS_OTHER = "diffusers_other"
  25. class ModelInfo(BaseModel):
  26. name: str
  27. path: str
  28. model_type: ModelType
  29. is_single_file_diffusers: bool = False
  30. @computed_field
  31. @property
  32. def need_prompt(self) -> bool:
  33. return self.model_type in [
  34. ModelType.DIFFUSERS_SD,
  35. ModelType.DIFFUSERS_SDXL,
  36. ModelType.DIFFUSERS_SD_INPAINT,
  37. ModelType.DIFFUSERS_SDXL_INPAINT,
  38. ] or self.name in [
  39. INSTRUCT_PIX2PIX_NAME,
  40. KANDINSKY22_NAME,
  41. POWERPAINT_NAME,
  42. ANYTEXT_NAME,
  43. ]
  44. @computed_field
  45. @property
  46. def controlnets(self) -> List[str]:
  47. if self.model_type in [
  48. ModelType.DIFFUSERS_SDXL,
  49. ModelType.DIFFUSERS_SDXL_INPAINT,
  50. ]:
  51. return SDXL_CONTROLNET_CHOICES
  52. if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
  53. if "sd2" in self.name.lower():
  54. return SD2_CONTROLNET_CHOICES
  55. else:
  56. return SD_CONTROLNET_CHOICES
  57. if self.name == POWERPAINT_NAME:
  58. return SD_CONTROLNET_CHOICES
  59. return []
  60. @computed_field
  61. @property
  62. def brushnets(self) -> List[str]:
  63. if self.model_type in [ModelType.DIFFUSERS_SD]:
  64. return SD_BRUSHNET_CHOICES
  65. if self.model_type in [ModelType.DIFFUSERS_SDXL]:
  66. return SDXL_BRUSHNET_CHOICES
  67. return []
  68. @computed_field
  69. @property
  70. def support_strength(self) -> bool:
  71. return self.model_type in [
  72. ModelType.DIFFUSERS_SD,
  73. ModelType.DIFFUSERS_SDXL,
  74. ModelType.DIFFUSERS_SD_INPAINT,
  75. ModelType.DIFFUSERS_SDXL_INPAINT,
  76. ] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
  77. @computed_field
  78. @property
  79. def support_outpainting(self) -> bool:
  80. return self.model_type in [
  81. ModelType.DIFFUSERS_SD,
  82. ModelType.DIFFUSERS_SDXL,
  83. ModelType.DIFFUSERS_SD_INPAINT,
  84. ModelType.DIFFUSERS_SDXL_INPAINT,
  85. ] or self.name in [KANDINSKY22_NAME, POWERPAINT_NAME]
  86. @computed_field
  87. @property
  88. def support_lcm_lora(self) -> bool:
  89. return self.model_type in [
  90. ModelType.DIFFUSERS_SD,
  91. ModelType.DIFFUSERS_SDXL,
  92. ModelType.DIFFUSERS_SD_INPAINT,
  93. ModelType.DIFFUSERS_SDXL_INPAINT,
  94. ]
  95. @computed_field
  96. @property
  97. def support_controlnet(self) -> bool:
  98. return self.model_type in [
  99. ModelType.DIFFUSERS_SD,
  100. ModelType.DIFFUSERS_SDXL,
  101. ModelType.DIFFUSERS_SD_INPAINT,
  102. ModelType.DIFFUSERS_SDXL_INPAINT,
  103. ]
  104. @computed_field
  105. @property
  106. def support_brushnet(self) -> bool:
  107. return self.model_type in [
  108. ModelType.DIFFUSERS_SD,
  109. ModelType.DIFFUSERS_SDXL,
  110. ]
  111. @computed_field
  112. @property
  113. def support_powerpaint_v2(self) -> bool:
  114. return (
  115. self.model_type
  116. in [
  117. ModelType.DIFFUSERS_SD,
  118. ]
  119. and self.name != POWERPAINT_NAME
  120. )
  121. class Choices(str, Enum):
  122. @classmethod
  123. def values(cls):
  124. return [member.value for member in cls]
class RealESRGANModel(Choices):
    """Selectable RealESRGAN upscaler checkpoints."""

    realesr_general_x4v3 = "realesr-general-x4v3"
    RealESRGAN_x4plus = "RealESRGAN_x4plus"
    RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
class RemoveBGModel(Choices):
    """Selectable background-removal models."""

    briaai_rmbg_1_4 = "briaai/RMBG-1.4"
    briaai_rmbg_2_0 = "briaai/RMBG-2.0"
    # models from https://github.com/danielgatis/rembg
    u2net = "u2net"
    u2netp = "u2netp"
    u2net_human_seg = "u2net_human_seg"
    u2net_cloth_seg = "u2net_cloth_seg"
    silueta = "silueta"
    isnet_general_use = "isnet-general-use"
    birefnet_general = "birefnet-general"
    birefnet_general_lite = "birefnet-general-lite"
    birefnet_portrait = "birefnet-portrait"
    birefnet_dis = "birefnet-dis"
    birefnet_hrsod = "birefnet-hrsod"
    birefnet_cod = "birefnet-cod"
    birefnet_massive = "birefnet-massive"
class Device(Choices):
    """Compute device a model or plugin runs on."""

    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"
class InteractiveSegModel(Choices):
    """Checkpoints for the interactive segmentation plugin (SAM / SAM-HQ / SAM2 families)."""

    vit_b = "vit_b"
    vit_l = "vit_l"
    vit_h = "vit_h"
    sam_hq_vit_b = "sam_hq_vit_b"
    sam_hq_vit_l = "sam_hq_vit_l"
    sam_hq_vit_h = "sam_hq_vit_h"
    mobile_sam = "mobile_sam"
    sam2_tiny = "sam2_tiny"
    sam2_small = "sam2_small"
    sam2_base = "sam2_base"
    sam2_large = "sam2_large"
    sam2_1_tiny = "sam2_1_tiny"
    sam2_1_small = "sam2_1_small"
    sam2_1_base = "sam2_1_base"
    sam2_1_large = "sam2_1_large"
class PluginInfo(BaseModel):
    """Capability flags for a plugin, reported to the client."""

    name: str
    # Whether the plugin produces a new image / a mask as its output.
    support_gen_image: bool = False
    support_gen_mask: bool = False
class CV2Flag(str, Enum):
    """OpenCV inpainting algorithm selector (mirrors cv2 flag names)."""

    INPAINT_NS = "INPAINT_NS"
    INPAINT_TELEA = "INPAINT_TELEA"
class HDStrategy(str, Enum):
    """Strategy for preprocessing high-resolution inputs in erase models."""

    # Use original image size
    ORIGINAL = "Original"
    # Resize the longer side of the image to a specific size(hd_strategy_resize_limit),
    # then do inpainting on the resized image. Finally, resize the inpainting result to the original size.
    # The area outside the mask will not lose quality.
    RESIZE = "Resize"
    # Crop masking area(with a margin controlled by hd_strategy_crop_margin) from the original image to do inpainting
    CROP = "Crop"
class LDMSampler(str, Enum):
    """Samplers supported by the LDM model."""

    ddim = "ddim"
    plms = "plms"
class SDSampler(str, Enum):
    """Scheduler/sampler names selectable for Stable Diffusion models.

    Values are the human-readable names shown in the UI; do not rename —
    they are part of the API contract.
    """

    dpm_plus_plus_2m = "DPM++ 2M"
    dpm_plus_plus_2m_karras = "DPM++ 2M Karras"
    dpm_plus_plus_2m_sde = "DPM++ 2M SDE"
    dpm_plus_plus_2m_sde_karras = "DPM++ 2M SDE Karras"
    dpm_plus_plus_sde = "DPM++ SDE"
    dpm_plus_plus_sde_karras = "DPM++ SDE Karras"
    dpm2 = "DPM2"
    dpm2_karras = "DPM2 Karras"
    dpm2_a = "DPM2 a"
    dpm2_a_karras = "DPM2 a Karras"
    euler = "Euler"
    euler_a = "Euler a"
    heun = "Heun"
    lms = "LMS"
    lms_karras = "LMS Karras"
    ddim = "DDIM"
    pndm = "PNDM"
    uni_pc = "UniPC"
    lcm = "LCM"
class PowerPaintTask(Choices):
    """Task presets understood by the PowerPaint model."""

    text_guided = "text-guided"
    context_aware = "context-aware"
    shape_guided = "shape-guided"
    object_remove = "object-remove"
    outpainting = "outpainting"
class ApiConfig(BaseModel):
    """Resolved server configuration (typically built from CLI arguments)."""

    host: str
    port: int
    inbrowser: bool  # open a browser after startup — TODO confirm against caller
    model: str  # name of the model to load at startup
    no_half: bool  # disable half-precision weights
    low_mem: bool
    cpu_offload: bool
    disable_nsfw_checker: bool
    local_files_only: bool  # presumably: never download from the hub — verify
    cpu_textencoder: bool  # run the text encoder on CPU
    device: Device
    # Batch-processing paths; None when not running in batch mode.
    input: Optional[Path]
    mask_dir: Optional[Path]
    output_dir: Optional[Path]
    quality: int  # output encoding quality — assumed 0-100 image quality, confirm
    # Plugin toggles plus per-plugin device/model selection.
    enable_interactive_seg: bool
    interactive_seg_model: InteractiveSegModel
    interactive_seg_device: Device
    enable_remove_bg: bool
    remove_bg_device: Device
    # NOTE(review): plain str although a RemoveBGModel enum exists — presumably
    # to accept model names beyond the enum; confirm before tightening the type.
    remove_bg_model: str
    enable_anime_seg: bool
    enable_realesrgan: bool
    realesrgan_device: Device
    realesrgan_model: RealESRGANModel
    enable_gfpgan: bool
    gfpgan_device: Device
    enable_restoreformer: bool
    restoreformer_device: Device
class InpaintRequest(BaseModel):
    """Request body for the inpaint endpoint.

    Mixes parameters for the erase models (lama/mat/...), the diffusion
    models (SD/SDXL/PowerPaint/...), and the OpenCV fallback; each backend
    reads only the fields relevant to it. The after-validator below
    normalizes mutually exclusive switches.
    """

    image: Optional[str] = Field(None, description="base64 encoded image")
    mask: Optional[str] = Field(None, description="base64 encoded mask")
    # LDM / ZITS specific knobs.
    ldm_steps: int = Field(20, description="Steps for ldm model.")
    ldm_sampler: str = Field(LDMSampler.plms, description="Sampler for ldm model.")
    zits_wireframe: bool = Field(True, description="Enable wireframe for zits model.")
    # High-resolution handling for erase models; see HDStrategy.
    hd_strategy: str = Field(
        HDStrategy.CROP,
        description="Different way to preprocess image, only used by erase models(e.g. lama/mat)",
    )
    hd_strategy_crop_trigger_size: int = Field(
        800,
        description="Crop trigger size for hd_strategy=CROP, if the longer side of the image is larger than this value, use crop strategy",
    )
    hd_strategy_crop_margin: int = Field(
        128, description="Crop margin for hd_strategy=CROP"
    )
    hd_strategy_resize_limit: int = Field(
        1280, description="Resize limit for hd_strategy=RESIZE"
    )
    # Diffusion model parameters.
    prompt: str = Field("", description="Prompt for diffusion models.")
    negative_prompt: str = Field(
        "", description="Negative prompt for diffusion models."
    )
    # Optional pre-crop around the masked region before diffusion inpainting.
    use_croper: bool = Field(
        False, description="Crop image before doing diffusion inpainting"
    )
    croper_x: int = Field(0, description="Crop x for croper")
    croper_y: int = Field(0, description="Crop y for croper")
    croper_height: int = Field(512, description="Crop height for croper")
    croper_width: int = Field(512, description="Crop width for croper")
    # Optional canvas extension for outpainting.
    use_extender: bool = Field(
        False, description="Extend image before doing sd outpainting"
    )
    extender_x: int = Field(0, description="Extend x for extender")
    extender_y: int = Field(0, description="Extend y for extender")
    extender_height: int = Field(640, description="Extend height for extender")
    extender_width: int = Field(640, description="Extend width for extender")
    sd_scale: float = Field(
        1.0,
        description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.",
        gt=0.0,
        le=1.0,
    )
    sd_mask_blur: int = Field(
        11,
        description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
    )
    sd_strength: float = Field(
        1.0,
        description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
        le=1.0,
    )
    sd_steps: int = Field(
        50,
        description="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.",
    )
    sd_guidance_scale: float = Field(
        7.5,
        description="Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.",
    )
    sd_sampler: str = Field(
        SDSampler.uni_pc, description="Sampler for diffusion model."
    )
    # validate_default=True so the -1 -> random-seed rewrite in the validator
    # also runs when the field is left at its default.
    sd_seed: int = Field(
        42,
        description="Seed for diffusion model. -1 mean random seed",
        validate_default=True,
    )
    sd_match_histograms: bool = Field(
        False,
        description="Match histograms between inpainting area and original image.",
    )
    sd_outpainting_softness: float = Field(20.0)
    sd_outpainting_space: float = Field(20.0)
    sd_lcm_lora: bool = Field(
        False,
        description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage",
    )
    sd_keep_unmasked_area: bool = Field(
        True, description="Keep unmasked area unchanged"
    )
    # OpenCV inpainting fallback parameters.
    cv2_flag: CV2Flag = Field(
        CV2Flag.INPAINT_NS,
        description="Flag for opencv inpainting: https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07",
    )
    cv2_radius: int = Field(
        4,
        description="Radius of a circular neighborhood of each point inpainted that is considered by the algorithm",
    )
    # Paint by Example
    paint_by_example_example_image: Optional[str] = Field(
        None, description="Base64 encoded example image for paint by example model"
    )
    # InstructPix2Pix
    p2p_image_guidance_scale: float = Field(1.5, description="Image guidance scale")
    # ControlNet
    enable_controlnet: bool = Field(False, description="Enable controlnet")
    controlnet_conditioning_scale: float = Field(
        0.4, description="Conditioning scale", ge=0.0, le=1.0
    )
    controlnet_method: str = Field(
        "lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
    )
    # BrushNet
    enable_brushnet: bool = Field(False, description="Enable brushnet")
    brushnet_method: str = Field(SD_BRUSHNET_CHOICES[0], description="Brushnet method")
    brushnet_conditioning_scale: float = Field(
        1.0, description="brushnet conditioning scale", ge=0.0, le=1.0
    )
    # PowerPaint
    enable_powerpaint_v2: bool = Field(False, description="Enable PowerPaint v2")
    powerpaint_task: PowerPaintTask = Field(
        PowerPaintTask.text_guided, description="PowerPaint task"
    )
    fitting_degree: float = Field(
        1.0,
        description="Control the fitting degree of the generated objects to the mask shape.",
        gt=0.0,
        le=1.0,
    )

    @model_validator(mode="after")
    def validate_field(cls, values: "InpaintRequest"):
        """Normalize interdependent switches after field validation.

        Order matters: BrushNet is checked before ControlNet, so when both
        are requested BrushNet wins and ControlNet is switched off.
        """
        # Resolve a -1 seed into a concrete random seed so downstream
        # components always see a reproducible value.
        if values.sd_seed == -1:
            values.sd_seed = random.randint(1, 99999999)
            logger.info(f"Generate random seed: {values.sd_seed}")

        # The extender (outpainting) requires full strength and no
        # controlnet conditioning.
        if values.use_extender and values.enable_controlnet:
            logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
            values.controlnet_conditioning_scale = 0

        if values.use_extender:
            logger.info("Extender is enabled, set sd_strength=1")
            values.sd_strength = 1.0

        # BrushNet is mutually exclusive with ControlNet and LCM-LoRA.
        if values.enable_brushnet:
            logger.info("BrushNet is enabled, set enable_controlnet=False")
            if values.enable_controlnet:
                values.enable_controlnet = False

            if values.sd_lcm_lora:
                logger.info("BrushNet is enabled, set sd_lcm_lora=False")
                values.sd_lcm_lora = False

        if values.enable_controlnet:
            logger.info("ControlNet is enabled, set enable_brushnet=False")
            if values.enable_brushnet:
                values.enable_brushnet = False

        return values
class RunPluginRequest(BaseModel):
    """Request body for running a plugin on an image."""

    name: str  # plugin name to dispatch to
    image: str = Field(..., description="base64 encoded image")
    # NOTE(review): mutable [] default is safe here — pydantic copies field
    # defaults per instance rather than sharing them.
    clicks: List[List[int]] = Field(
        [], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
    )
    scale: float = Field(2.0, description="Scale for upscaling")
# Which file-manager tab a media request refers to.
MediaTab = Literal["input", "output", "mask"]
class MediasResponse(BaseModel):
    """A single file entry in a media-listing response."""

    name: str
    height: int
    width: int
    ctime: float  # presumably filesystem creation time in epoch seconds — confirm
    mtime: float  # presumably filesystem modification time in epoch seconds — confirm
class GenInfoResponse(BaseModel):
    """Prompt / negative-prompt pair returned by the generation-info endpoint."""

    prompt: str = ""
    negative_prompt: str = ""
class ServerConfigResponse(BaseModel):
    """Server capabilities and current model/plugin state sent to the client.

    NOTE: the camelCase field names are part of the public JSON contract —
    do not rename them.
    """

    plugins: List[PluginInfo]
    modelInfos: List[ModelInfo]
    removeBGModel: RemoveBGModel
    removeBGModels: List[RemoveBGModel]
    realesrganModel: RealESRGANModel
    realesrganModels: List[RealESRGANModel]
    interactiveSegModel: InteractiveSegModel
    interactiveSegModels: List[InteractiveSegModel]
    enableFileManager: bool
    enableAutoSaving: bool
    enableControlnet: bool
    controlnetMethod: Optional[str]
    disableModelSwitch: bool
    isDesktop: bool
    samplers: List[str]
class SwitchModelRequest(BaseModel):
    """Request body for switching the active inpaint/diffusion model."""

    name: str  # model name to switch to
class SwitchPluginModelRequest(BaseModel):
    """Request body for switching the model used by a specific plugin."""

    plugin_name: str
    model_name: str
# Allowed mask post-processing operations.
AdjustMaskOperate = Literal["expand", "shrink", "reverse"]


class AdjustMaskRequest(BaseModel):
    """Request body for morphologically adjusting or inverting a mask."""

    mask: str = Field(
        ..., description="base64 encoded mask. 255 means area to do inpaint"
    )
    operate: AdjustMaskOperate = Field(..., description="expand/shrink/reverse")
    kernel_size: int = Field(5, description="Kernel size for expanding mask")