test_flux.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """测试 BFL FLUX 异步生图 — 通过 Router POST /run_tool
  2. 官方流程:先 POST 提交任务拿到 id + polling_url,再轮询 polling_url 直至 Ready。
  3. 文档: https://docs.bfl.ai/quick_start/generating_images
  4. 用法:
  5. 1. 配置 tools/local/flux/.env:BFL_API_KEY
  6. 2. uv run python -m tool_agent
  7. 3. uv run python tests/test_flux.py
  8. 模型切换:
  9. FLUX_TEST_MODEL=flux-2-max uv run python tests/test_flux.py
  10. (model 为路径段,如 flux-2-pro-preview、flux-2-pro、flux-dev 等,见官方 Available Endpoints)
  11. 环境变量:
  12. TOOL_AGENT_ROUTER_URL 默认 http://127.0.0.1:8001
  13. FLUX_SUBMIT_TOOL_ID 默认 flux_submit
  14. FLUX_QUERY_TOOL_ID 默认 flux_query
  15. FLUX_TEST_MODEL 默认 flux-2-pro-preview
  16. FLUX_TEST_PROMPT 覆盖默认短提示词
  17. FLUX_POLL_INTERVAL_S 默认 1.0
  18. FLUX_POLL_MAX_WAIT_S 默认 300
  19. """
  20. from __future__ import annotations
  21. import io
  22. import os
  23. import sys
  24. import time
  25. from typing import Any
  26. if sys.platform == "win32":
  27. _out = sys.stdout
  28. if isinstance(_out, io.TextIOWrapper):
  29. _out.reconfigure(encoding="utf-8")
  30. import httpx
  31. ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
  32. SUBMIT_TOOL = os.environ.get("FLUX_SUBMIT_TOOL_ID", "flux_submit")
  33. QUERY_TOOL = os.environ.get("FLUX_QUERY_TOOL_ID", "flux_query")
  34. FLUX_MODEL = os.environ.get("FLUX_TEST_MODEL", "flux-2-pro-preview").strip()
  35. TEST_PROMPT = os.environ.get(
  36. "FLUX_TEST_PROMPT",
  37. "A tiny red apple on white background, simple product photo, minimal",
  38. )
  39. POLL_INTERVAL_S = float(os.environ.get("FLUX_POLL_INTERVAL_S", "1.0"))
  40. POLL_MAX_WAIT_S = float(os.environ.get("FLUX_POLL_MAX_WAIT_S", "300"))
  41. def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> dict[str, Any]:
  42. resp = httpx.post(
  43. f"{ROUTER_URL}/run_tool",
  44. json={"tool_id": tool_id, "params": params},
  45. timeout=timeout,
  46. )
  47. resp.raise_for_status()
  48. body = resp.json()
  49. if body.get("status") != "success":
  50. raise RuntimeError(body.get("error") or str(body))
  51. result = body.get("result")
  52. if isinstance(result, dict) and result.get("status") == "error":
  53. raise RuntimeError(result.get("error", str(result)))
  54. return result if isinstance(result, dict) else {}
  55. def _poll_terminal_success(data: dict[str, Any]) -> bool:
  56. s = str(data.get("status") or "").strip()
  57. return s.lower() == "ready"
  58. def _poll_terminal_failure(data: dict[str, Any]) -> bool:
  59. s = str(data.get("status") or "").strip().lower()
  60. return s in ("error", "failed")
  61. def _sample_url(data: dict[str, Any]) -> str | None:
  62. r = data.get("result")
  63. if isinstance(r, dict):
  64. u = r.get("sample")
  65. if isinstance(u, str) and u.startswith("http"):
  66. return u
  67. return None
  68. def main() -> None:
  69. print("=" * 50)
  70. print("测试 FLUX(BFL 异步 API + 模型可切换)")
  71. print("=" * 50)
  72. print(f"ROUTER_URL: {ROUTER_URL}")
  73. print(f"model: {FLUX_MODEL}")
  74. try:
  75. r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
  76. print(f"Router 状态: {r.json()}")
  77. except httpx.ConnectError:
  78. print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
  79. sys.exit(1)
  80. print("\n--- 校验工具已注册 ---")
  81. tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
  82. tr.raise_for_status()
  83. tools = tr.json().get("tools", [])
  84. ids = {t["tool_id"] for t in tools}
  85. for tid in (SUBMIT_TOOL, QUERY_TOOL):
  86. if tid not in ids:
  87. print(f"错误: {tid!r} 不在 GET /tools 中。示例 id: {sorted(ids)[:20]}...")
  88. sys.exit(1)
  89. meta = next(t for t in tools if t["tool_id"] == tid)
  90. print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})")
  91. props = (next(t for t in tools if t["tool_id"] == SUBMIT_TOOL).get("input_schema") or {}).get(
  92. "properties"
  93. ) or {}
  94. if "model" in props:
  95. print(" flux_submit input_schema 已声明 model")
  96. else:
  97. print(" 提示: flux_submit 宜在注册表中声明 model 以便切换端点")
  98. print("\n--- flux_submit ---")
  99. submit_params: dict[str, Any] = {
  100. "model": FLUX_MODEL,
  101. "prompt": TEST_PROMPT,
  102. "width": 512,
  103. "height": 512,
  104. }
  105. try:
  106. sub = run_tool(SUBMIT_TOOL, submit_params, timeout=120.0)
  107. except (RuntimeError, httpx.HTTPError) as e:
  108. print(f"错误: {e}")
  109. sys.exit(1)
  110. print(f"提交返回 keys: {list(sub.keys())}")
  111. req_id = sub.get("id") or sub.get("request_id")
  112. poll_url = sub.get("polling_url")
  113. if not req_id or not poll_url:
  114. print(f"错误: 缺少 id 或 polling_url: {sub}")
  115. sys.exit(1)
  116. print(f"request id: {req_id}")
  117. print(f"polling_url: {poll_url[:80]}...")
  118. print("\n--- flux_query 轮询 ---")
  119. deadline = time.monotonic() + POLL_MAX_WAIT_S
  120. last: dict[str, Any] = {}
  121. while time.monotonic() < deadline:
  122. time.sleep(POLL_INTERVAL_S)
  123. try:
  124. last = run_tool(
  125. QUERY_TOOL,
  126. {"polling_url": str(poll_url), "request_id": str(req_id)},
  127. timeout=60.0,
  128. )
  129. except (RuntimeError, httpx.HTTPError) as e:
  130. print(f"轮询错误: {e}")
  131. sys.exit(1)
  132. st = last.get("status")
  133. print(f" status: {st}")
  134. if _poll_terminal_failure(last):
  135. print(f"生成失败: {last}")
  136. sys.exit(1)
  137. if _poll_terminal_success(last):
  138. url = _sample_url(last)
  139. if url:
  140. print(f"\n图片 URL(signed,约 10 分钟内有效): {url[:100]}...")
  141. print("\n测试通过!")
  142. return
  143. print(f"\n等待超时 ({POLL_MAX_WAIT_S}s),最后一次: {last}")
  144. sys.exit(1)
  145. if __name__ == "__main__":
  146. main()