main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from fastapi import FastAPI, HTTPException
  2. from pydantic import BaseModel, Field
  3. from typing import Optional
  4. import uvicorn
  5. import sys
  6. import os
  7. from liblibai_client import LibLibAIClient
  # The FastAPI application; the request models and path operations in this
  # module are registered on it.
  app = FastAPI(
      title="LibLib AI 综合生成工具 API",
  )
  class SearchRequest(BaseModel):
      """Request body for ``POST /search_models``."""

      # Required search keyword, passed to ``LibLibAIClient.search_models``.
      keyword: str = Field(default=..., description="搜索关键词")
  class DetailRequest(BaseModel):
      """Request body for ``POST /model_detail``.

      All three identifiers are optional at the schema level and are passed
      unchanged to ``LibLibAIClient.get_model_detail``.
      """

      # Full link to the model's content page.
      content_link: Optional[str] = Field(default=None, description="完整内容链接")
      # Model uuid.
      uuid: Optional[str] = Field(default=None, description="模型 uuid")
      # Model version uuid (``versionUuid``).
      version_uuid: Optional[str] = Field(default=None, description="版本 versionUuid")
  class ControlNetItem(BaseModel):
      """One ControlNet entry of ``GenerateRequest.control_nets``."""

      # Control type, e.g. canny / softedge / lineart / openpose / depth.
      mode: str = Field(default=..., description="控制类型: canny, softedge, lineart, openpose, depth")
      # Reference image, given as a URL or a Base64 string.
      image: str = Field(default=..., description="参考图片 URL 或 Base64")
      # Control weight; defaults to 1.0.
      weight: float = Field(default=1.0, description="控制权重,通常为 1.0")
  class GenerateRequest(BaseModel):
      """Request body for ``POST /generate``.

      The ``/generate`` endpoint forwards every field to
      ``LibLibAIClient.generate_advanced``. It lower-cases ``mode`` and converts
      ``control_nets`` entries to plain dicts first.
      """

      # --- what to generate ---
      mode: str = Field(default="text2img", description="生成模式:text2img, img2img, inpaint, instantid, 或者是单个控制网如 canny, openpose 等")
      prompt: str

      # --- reference / control images ---
      image: Optional[str] = Field(default=None, description="单图模式的参考图,或者图生图的原图。")
      control_nets: Optional[list[ControlNetItem]] = Field(default=None, description="多路并发控制网列表。若提供此项,忽略外层的单图 mode 设置。")
      mask_image: Optional[str] = Field(default=None, description="蒙版图 (仅 inpaint 模式需要)")
      pose_image: Optional[str] = Field(default=None, description="姿态参考图 (仅 instantid 模式需要)")

      # --- sampling parameters ---
      negative_prompt: str = Field(default="lowres, bad anatomy, text, error")
      width: int = Field(default=512)
      height: int = Field(default=512)
      steps: int = Field(default=20)
      cfg_scale: float = Field(default=7.0)
      img_count: int = Field(default=1)

      # --- model selection ---
      base_model_uuid: Optional[str] = Field(default=None, description="动态传入的底模 UUID(versionUuid 即 checkpointId),覆盖默认底模")
  @app.post("/generate")
  async def generate(req: GenerateRequest):
      """Run a generation job through ``LibLibAIClient.generate_advanced``.

      Every request field is forwarded unchanged, except that ``mode`` is
      lower-cased and ``control_nets`` entries are converted to plain dicts.

      Returns:
          Whatever ``generate_advanced`` returns, serialised by FastAPI.

      Raises:
          HTTPException: 400 when the client raises ``ValueError`` (treated as a
              bad request), 500 for any other error.
      """
      # Local import keeps this endpoint self-contained.
      import asyncio

      def _call_client():
          # The client API is synchronous (its result is used without
          # ``await``), so the whole construct-and-call sequence runs here,
          # off the event loop.
          client = LibLibAIClient()
          control_nets_dicts = (
              [c.dict() for c in req.control_nets] if req.control_nets else None
          )
          return client.generate_advanced(
              mode=req.mode.lower(),
              prompt=req.prompt,
              image=req.image,
              control_nets=control_nets_dicts,
              mask_image=req.mask_image,
              pose_image=req.pose_image,
              negative_prompt=req.negative_prompt,
              width=req.width,
              height=req.height,
              steps=req.steps,
              cfg_scale=req.cfg_scale,
              img_count=req.img_count,
              base_model_uuid=req.base_model_uuid,
          )

      try:
          # Calling a blocking function directly inside ``async def`` freezes
          # the event loop (and every other in-flight request) for the whole
          # call; presumably that includes network round-trips to LibLib AI.
          # A worker thread keeps the server responsive.
          return await asyncio.to_thread(_call_client)
      except ValueError as ve:
          raise HTTPException(status_code=400, detail=str(ve)) from ve
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e)) from e
  @app.get("/health")
  async def health():
      """Liveness check that always answers ``{"status": "ok"}``."""
      return dict(status="ok")
  @app.post("/search_models")
  async def search_models(req: SearchRequest):
      """Search models by keyword via ``LibLibAIClient.search_models``.

      Returns:
          Whatever the client returns, serialised by FastAPI.

      Raises:
          HTTPException: 500 wrapping any error raised by the client.
      """
      # Local import keeps this endpoint self-contained.
      import asyncio

      def _call_client():
          client = LibLibAIClient()
          return client.search_models(keyword=req.keyword)

      try:
          # The client call is synchronous; running it in a worker thread stops
          # it from blocking the event loop (and thus all other requests).
          return await asyncio.to_thread(_call_client)
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e)) from e
  @app.post("/model_detail")
  async def get_model_detail(req: DetailRequest):
      """Fetch model details via ``LibLibAIClient.get_model_detail``.

      ``content_link``, ``uuid`` and ``version_uuid`` are forwarded as given.
      NOTE(review): all three are optional in the schema; presumably the client
      needs at least one of them. Confirm against ``LibLibAIClient``.

      Returns:
          Whatever the client returns, serialised by FastAPI.

      Raises:
          HTTPException: 500 wrapping any error raised by the client.
      """
      # Local import keeps this endpoint self-contained.
      import asyncio

      def _call_client():
          client = LibLibAIClient()
          return client.get_model_detail(
              content_link=req.content_link,
              uuid=req.uuid,
              version_uuid=req.version_uuid,
          )

      try:
          # The client call is synchronous; running it in a worker thread stops
          # it from blocking the event loop (and thus all other requests).
          return await asyncio.to_thread(_call_client)
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e)) from e
  @app.get("/uuid_matching_rules")
  async def get_uuid_matching_rules():
      """Return the UUID matching-rules markdown as ``{"content": <text>}``.

      The file is ``docs/liblibai_uuid_matching_rules.md`` under the directory
      obtained by applying ``os.path.dirname`` four times to ``__file__``. That
      is three levels above the directory containing this module. Any failure,
      such as a missing file or a decode error, is reported as HTTP 500.
      """
      try:
          project_root = __file__
          for _ in range(4):
              project_root = os.path.dirname(project_root)
          rules_path = os.path.join(project_root, "docs", "liblibai_uuid_matching_rules.md")
          with open(rules_path, "r", encoding="utf-8") as handle:
              markdown_text = handle.read()
          return {"content": markdown_text}
      except Exception as e:
          raise HTTPException(status_code=500, detail=str(e))
  def parse_cli_args(argv=None):
      """Parse the command-line options used when this module is run directly.

      Recognises ``--port`` (int, default 8001) and ``--host`` (default
      ``"0.0.0.0"``, the previously hard-coded bind address). Unrecognised
      arguments are ignored, as the earlier hand-rolled ``sys.argv`` scan did.
      Both ``--port 9000`` and ``--port=9000`` are accepted. A missing or
      non-integer port makes argparse print a usage error and exit, instead of
      crashing with ``IndexError``/``ValueError``.

      Args:
          argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

      Returns:
          ``argparse.Namespace`` with ``port`` and ``host`` attributes.
      """
      import argparse

      # allow_abbrev=False: don't let unrelated flags such as "--p" be treated
      # as abbreviations of our options.
      parser = argparse.ArgumentParser(
          description="Run the LibLib AI API server.", allow_abbrev=False
      )
      parser.add_argument("--port", type=int, default=8001, help="port to listen on (default: 8001)")
      parser.add_argument("--host", default="0.0.0.0", help="interface to bind (default: 0.0.0.0)")
      args, _unknown = parser.parse_known_args(argv)
      return args


  if __name__ == "__main__":
      cli_args = parse_cli_args()
      uvicorn.run(app, host=cli_args.host, port=cli_args.port)