| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
import argparse
import asyncio
import logging
import os
import sys
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from liblibai_client import LibLibAIClient
# The FastAPI application; every route in this module registers on it.
app = FastAPI(
    title="LibLib AI 综合生成工具 API",
)
class SearchRequest(BaseModel):
    """Request body for ``POST /search_models``."""

    # Search keyword, forwarded as ``keyword=`` to ``LibLibAIClient.search_models``.
    keyword: str = Field(..., description="搜索关键词")
class DetailRequest(BaseModel):
    """Request body for ``POST /model_detail``.

    All three identifiers are optional at the schema level and are passed
    through unchanged to ``LibLibAIClient.get_model_detail``.
    NOTE(review): presumably at least one of them must be set (or a specific
    combination) — the rule lives in the client; confirm there.
    """

    # Full content (page) link of the model.
    content_link: Optional[str] = Field(None, description="完整内容链接")
    # Model uuid.
    uuid: Optional[str] = Field(None, description="模型 uuid")
    # Version identifier (the service's ``versionUuid``).
    version_uuid: Optional[str] = Field(None, description="版本 versionUuid")
class ControlNetItem(BaseModel):
    """One ControlNet entry in ``GenerateRequest.control_nets``.

    The ``/generate`` handler converts each entry to a plain dict before
    handing the list to the client.
    """

    # Control type; the description lists canny, softedge, lineart, openpose, depth.
    mode: str = Field(..., description="控制类型: canny, softedge, lineart, openpose, depth")
    # Reference image, given as a URL or a Base64 string.
    image: str = Field(..., description="参考图片 URL 或 Base64")
    # Control weight; defaults to 1.0 (the description says "usually 1.0").
    weight: float = Field(1.0, description="控制权重,通常为 1.0")
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    The ``/generate`` handler forwards every field to
    ``LibLibAIClient.generate_advanced``. It lower-cases ``mode`` first and
    turns ``control_nets`` into a list of dicts. No per-mode requirements or
    numeric range checks are validated here; those are left to the client.
    """

    # Generation mode: text2img, img2img, inpaint, instantid, or a single
    # ControlNet type such as canny / openpose (per the field description).
    mode: str = Field("text2img", description="生成模式:text2img, img2img, inpaint, instantid, 或者是单个控制网如 canny, openpose 等")
    # Prompt text; required (no default).
    prompt: str
    # Reference image for single-image modes, or the source image for img2img.
    image: Optional[str] = Field(None, description="单图模式的参考图,或者图生图的原图。")
    # Several ControlNets applied together. Per the description, when this is
    # provided the outer single-image ``mode`` is ignored.
    # NOTE(review): that precedence is enforced in the client, not here — confirm.
    control_nets: Optional[list[ControlNetItem]] = Field(None, description="多路并发控制网列表。若提供此项,忽略外层的单图 mode 设置。")
    # Mask image; per the description only needed in inpaint mode.
    mask_image: Optional[str] = Field(None, description="蒙版图 (仅 inpaint 模式需要)")
    # Pose reference image; per the description only needed in instantid mode.
    pose_image: Optional[str] = Field(None, description="姿态参考图 (仅 instantid 模式需要)")
    negative_prompt: str = "lowres, bad anatomy, text, error"
    # Output dimensions — presumably pixels; confirm against the client.
    width: int = 512
    height: int = 512
    # Sampler step count and CFG scale (meanings assumed from the names).
    steps: int = 20
    cfg_scale: float = 7.0
    # Number of images to request.
    img_count: int = 1
    # Base-model UUID (the versionUuid, i.e. checkpointId) that overrides the
    # client's default base model when provided.
    base_model_uuid: Optional[str] = Field(None, description="动态传入的底模 UUID(versionUuid 即 checkpointId),覆盖默认底模")
@app.post("/generate")
async def generate(req: GenerateRequest):
    """Run an image-generation job through ``LibLibAIClient.generate_advanced``.

    Every request field is forwarded to the client. ``mode`` is lower-cased
    first, and ``control_nets`` entries are converted to plain dicts.

    Returns:
        Whatever ``generate_advanced`` returns, serialized by FastAPI.

    Raises:
        HTTPException: 400 when the client raises ``ValueError``
            (presumably invalid parameters — confirm in the client),
            500 for any other failure.
    """

    def _run_generation():
        # Fresh client per request, so the worker thread shares no client state.
        client = LibLibAIClient()
        # NOTE(review): ``.dict()`` is deprecated in Pydantic v2 in favour of
        # ``.model_dump()``; it still works there but emits a warning.
        control_nets_dicts = [c.dict() for c in req.control_nets] if req.control_nets else None
        return client.generate_advanced(
            mode=req.mode.lower(),
            prompt=req.prompt,
            image=req.image,
            control_nets=control_nets_dicts,
            mask_image=req.mask_image,
            pose_image=req.pose_image,
            negative_prompt=req.negative_prompt,
            width=req.width,
            height=req.height,
            steps=req.steps,
            cfg_scale=req.cfg_scale,
            img_count=req.img_count,
            base_model_uuid=req.base_model_uuid,
        )

    try:
        # The client is called synchronously (no ``await``). Calling it
        # directly inside this coroutine would block the event loop for the
        # whole generation and stall every other request, including /health.
        # A worker thread keeps the loop responsive.
        return await asyncio.to_thread(_run_generation)
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Keep the real traceback on the server; the client only gets the message.
        logging.getLogger(__name__).exception("Image generation failed (mode=%r)", req.mode)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health():
    """Liveness probe; unconditionally reports ``{"status": "ok"}``."""
    state = "ok"
    return {"status": state}
@app.post("/search_models")
async def search_models(req: SearchRequest):
    """Search LibLib AI models by keyword via ``LibLibAIClient.search_models``.

    Returns:
        The client's search result, serialized by FastAPI.

    Raises:
        HTTPException: 400 when the client raises ``ValueError`` (the same
            mapping ``/generate`` uses), 500 for any other failure.
    """

    def _run_search():
        client = LibLibAIClient()
        return client.search_models(keyword=req.keyword)

    try:
        # The client is called synchronously (no ``await``); run it in a
        # worker thread so a slow upstream search does not block the event loop.
        return await asyncio.to_thread(_run_search)
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logging.getLogger(__name__).exception("Model search failed (keyword=%r)", req.keyword)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/model_detail")
async def get_model_detail(req: DetailRequest):
    """Fetch model details via ``LibLibAIClient.get_model_detail``.

    ``content_link``, ``uuid`` and ``version_uuid`` are passed through
    unchanged. All three are optional in the request schema, so a rejection
    of the identifier combination by the client is reported as a client error.

    Returns:
        The client's detail payload, serialized by FastAPI.

    Raises:
        HTTPException: 400 when the client raises ``ValueError`` (the same
            mapping ``/generate`` uses), 500 for any other failure.
    """

    def _run_lookup():
        client = LibLibAIClient()
        return client.get_model_detail(
            content_link=req.content_link,
            uuid=req.uuid,
            version_uuid=req.version_uuid,
        )

    try:
        # The client is called synchronously (no ``await``); keep it off the event loop.
        return await asyncio.to_thread(_run_lookup)
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        logging.getLogger(__name__).exception("Model detail lookup failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/uuid_matching_rules")
async def get_uuid_matching_rules():
    """Return the UUID-matching rules document as ``{"content": <markdown text>}``.

    The file is ``docs/liblibai_uuid_matching_rules.md`` under the directory
    four levels above this module (presumably the repository root — confirm
    against the project layout). It is read as UTF-8 on every call.

    Raises:
        HTTPException: 404 if the document does not exist, 500 on any other
            read or decode failure.
    """
    # abspath() first: with a relative ``__file__`` the repeated dirname()
    # calls would collapse to "" and silently resolve against the CWD.
    base_dir = os.path.abspath(__file__)
    for _ in range(4):  # four levels up, as in the original layout
        base_dir = os.path.dirname(base_dir)
    doc_path = os.path.join(base_dir, "docs", "liblibai_uuid_matching_rules.md")
    try:
        with open(doc_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError as e:
        # str(e) would expose the absolute server path to the caller; log it instead.
        logging.getLogger(__name__).error("UUID matching rules document missing: %s", doc_path)
        raise HTTPException(status_code=404, detail="UUID matching rules document not found") from e
    except Exception as e:
        logging.getLogger(__name__).exception("Failed to read %s", doc_path)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"content": content}
if __name__ == "__main__":
    # Parse only the flags this launcher owns. parse_known_args ignores any
    # other argv entries, as the old hand-rolled lookup did. argparse also
    # rejects a missing or non-integer --port value with a usage message
    # instead of an IndexError/ValueError traceback, and accepts ``--port=N``.
    parser = argparse.ArgumentParser(description="LibLib AI API server", allow_abbrev=False)
    parser.add_argument("--port", type=int, default=8001, help="port to listen on (default: 8001)")
    parser.add_argument("--host", default="0.0.0.0", help="interface to bind (default: 0.0.0.0)")
    args, _unknown = parser.parse_known_args()
    uvicorn.run(app, host=args.host, port=args.port)
|