main.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """BFL FLUX 本地封装 — 异步生图(提交 + 轮询)。
  2. 环境变量:
  3. BFL_API_KEY 必填,请求头 x-key
  4. BFL_API_BASE 可选,默认 https://api.bfl.ai/v1(全球端点;也可用 api.eu.bfl.ai/v1 等)
  5. 接口:
  6. GET /health
  7. POST /submit 提交生成;body.model 为端点路径段,如 flux-2-pro-preview
  8. POST /query 轮询;polling_url、request_id 来自提交响应
  9. 文档: https://docs.bfl.ai/quick_start/generating_images
  10. """
from __future__ import annotations

import argparse
import logging
import os
from typing import Any
from urllib.parse import urlsplit

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from bfl_client import poll_result, submit_generation
  18. app = FastAPI(title="BFL FLUX API Proxy")
  19. class SubmitRequest(BaseModel):
  20. model: str = Field(
  21. ...,
  22. description="模型端点路径段,如 flux-2-pro-preview、flux-2-max、flux-dev(对应 /v1/{model})",
  23. )
  24. prompt: str = Field(..., description="文生图提示词")
  25. width: int | None = Field(default=None, description="输出宽度(像素)")
  26. height: int | None = Field(default=None, description="输出高度(像素)")
  27. parameters: dict[str, Any] | None = Field(
  28. default=None,
  29. description="合并进请求体的额外字段(官方各模型可选参数)",
  30. )
  31. class QueryRequest(BaseModel):
  32. polling_url: str = Field(..., description="提交响应中的 polling_url,须原样使用")
  33. request_id: str = Field(..., description="提交响应中的 id")
  34. @app.get("/health")
  35. def health() -> dict[str, str]:
  36. return {"status": "ok"}
  37. @app.post("/submit")
  38. def submit(req: SubmitRequest) -> dict[str, Any]:
  39. try:
  40. return submit_generation(
  41. model=req.model,
  42. prompt=req.prompt,
  43. width=req.width,
  44. height=req.height,
  45. parameters=req.parameters,
  46. )
  47. except ValueError as e:
  48. raise HTTPException(status_code=503, detail=str(e)) from e
  49. except RuntimeError as e:
  50. raise HTTPException(status_code=502, detail=str(e)) from e
  51. except Exception as e:
  52. raise HTTPException(status_code=502, detail=str(e)) from e
  53. @app.post("/query")
  54. def query(req: QueryRequest) -> dict[str, Any]:
  55. try:
  56. return poll_result(polling_url=req.polling_url, request_id=req.request_id)
  57. except ValueError as e:
  58. raise HTTPException(status_code=503, detail=str(e)) from e
  59. except RuntimeError as e:
  60. raise HTTPException(status_code=502, detail=str(e)) from e
  61. except Exception as e:
  62. raise HTTPException(status_code=502, detail=str(e)) from e
  63. if __name__ == "__main__":
  64. parser = argparse.ArgumentParser()
  65. parser.add_argument("--port", type=int, default=8001)
  66. args = parser.parse_args()
  67. uvicorn.run(app, host="0.0.0.0", port=args.port)