bfl_client.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  """HTTP client for BFL FLUX: submit a generation job, then poll for its result.

  Reference: https://docs.bfl.ai/quick_start/generating_images
  """
  from __future__ import annotations

  import os
  from typing import Any

  import httpx
  from dotenv import load_dotenv

  # Load variables from a local .env file (python-dotenv) at import time so
  # BFL_API_KEY / BFL_API_BASE can be provided that way; the return value is
  # intentionally discarded.
  _ = load_dotenv()

  # API root used when the BFL_API_BASE environment variable is not set.
  DEFAULT_API_BASE = "https://api.bfl.ai/v1"
  def _api_key() -> str:
      """Return the BFL API key taken from the environment.

      Reads ``BFL_API_KEY`` and strips surrounding whitespace.

      Raises:
          ValueError: if the variable is unset or contains only whitespace.
      """
      raw_value = os.environ.get("BFL_API_KEY", "")
      cleaned = raw_value.strip()
      if cleaned:
          return cleaned
      raise ValueError("缺少环境变量 BFL_API_KEY")
  def _headers() -> dict[str, str]:
      """Build the headers for a JSON request to the BFL API.

      The ``x-key`` value comes from ``_api_key()``, so this raises
      ``ValueError`` when ``BFL_API_KEY`` is not configured.
      """
      headers: dict[str, str] = {"accept": "application/json"}
      headers["x-key"] = _api_key()
      headers["Content-Type"] = "application/json"
      return headers
  def submit_generation(
      *,
      model: str,
      prompt: str,
      width: int | None = None,
      height: int | None = None,
      parameters: dict[str, Any] | None = None,
      timeout: float = 120.0,
  ) -> dict[str, Any]:
      """Submit a generation job via ``POST {BFL_API_BASE}/{model}``.

      Args:
          model: Endpoint path appended to the API base. Surrounding
              whitespace and leading slashes are stripped.
          prompt: Text prompt. Always overrides any ``"prompt"`` key in
              *parameters*.
          width: When not None, overrides ``parameters["width"]``.
          height: When not None, overrides ``parameters["height"]``.
          parameters: Extra request-body fields. The dict is copied, never
              mutated.
          timeout: httpx timeout in seconds for this request (default 120).

      Returns:
          The decoded JSON object from BFL (for example ``id`` and
          ``polling_url``; the exact fields are defined by BFL's response).

      Raises:
          ValueError: if ``BFL_API_KEY`` is not set.
          httpx.HTTPStatusError: if the body is not JSON and
              ``raise_for_status()`` rejects the status code.
          RuntimeError: on a status >= 400 with a JSON body, on a non-JSON
              body that httpx does not reject, or when the JSON is not an
              object.
      """
      base = os.environ.get("BFL_API_BASE", DEFAULT_API_BASE).rstrip("/")
      model_path = model.strip().lstrip("/")
      url = f"{base}/{model_path}"

      body: dict[str, Any] = dict(parameters) if parameters else {}
      body["prompt"] = prompt
      if width is not None:
          body["width"] = width
      if height is not None:
          body["height"] = height

      with httpx.Client(timeout=timeout) as client:
          r = client.post(url, headers=_headers(), json=body)

      try:
          data = r.json()
      except ValueError:
          # Body is not JSON (e.g. an HTML error page from a proxy): let httpx
          # report HTTP error statuses, otherwise surface the raw text.
          r.raise_for_status()
          raise RuntimeError(r.text[:2000]) from None
      if r.status_code >= 400:
          # Prefer the "detail" field of a JSON error body for a concise message.
          err = data.get("detail") if isinstance(data, dict) else None
          msg = err if err is not None else str(data)
          raise RuntimeError(f"BFL HTTP {r.status_code}: {msg}")
      if not isinstance(data, dict):
          raise RuntimeError("提交响应不是 JSON 对象")
      return data
  def poll_result(
      *,
      polling_url: str,
      request_id: str,
      timeout: float = 60.0,
  ) -> dict[str, Any]:
      """Fetch the current state of a submitted job.

      Sends ``GET polling_url`` with the query ``id=request_id``, matching the
      official BFL example.

      Args:
          polling_url: URL to poll. Surrounding whitespace is stripped.
          request_id: Job id sent as the ``id`` query parameter (stripped).
          timeout: httpx timeout in seconds for this request (default 60).

      Returns:
          The decoded JSON object returned by BFL.

      Raises:
          ValueError: if ``BFL_API_KEY`` is not set.
          httpx.HTTPStatusError: if the body is not JSON and
              ``raise_for_status()`` rejects the status code.
          RuntimeError: on a status >= 400 with a JSON body, on a non-JSON
              body that httpx does not reject, or when the JSON is not an
              object.
      """
      with httpx.Client(timeout=timeout) as client:
          r = client.get(
              polling_url.strip(),
              headers={
                  "accept": "application/json",
                  "x-key": _api_key(),
              },
              params={"id": request_id.strip()},
          )

      try:
          data = r.json()
      except ValueError:
          # Body is not JSON: let httpx report HTTP error statuses, otherwise
          # surface the raw text.
          r.raise_for_status()
          raise RuntimeError(r.text[:2000]) from None
      if r.status_code >= 400:
          # The f-string renders dicts and other JSON values via str().
          raise RuntimeError(f"BFL poll HTTP {r.status_code}: {data}")
      if not isinstance(data, dict):
          raise RuntimeError("轮询响应不是 JSON 对象")
      return data