|
|
@@ -0,0 +1,177 @@
|
|
|
+"""测试 BFL FLUX 异步生图 — 通过 Router POST /run_tool
|
|
|
+
|
|
|
+官方流程:先 POST 提交任务拿到 id + polling_url,再轮询 polling_url 直至 Ready。
|
|
|
+文档: https://docs.bfl.ai/quick_start/generating_images
|
|
|
+
|
|
|
+用法:
|
|
|
+ 1. 配置 tools/local/flux/.env:BFL_API_KEY
|
|
|
+ 2. uv run python -m tool_agent
|
|
|
+ 3. uv run python tests/test_flux.py
|
|
|
+
|
|
|
+模型切换:
|
|
|
+ FLUX_TEST_MODEL=flux-2-max uv run python tests/test_flux.py
|
|
|
+ (model 为路径段,如 flux-2-pro-preview、flux-2-pro、flux-dev 等,见官方 Available Endpoints)
|
|
|
+
|
|
|
+环境变量:
|
|
|
+ TOOL_AGENT_ROUTER_URL 默认 http://127.0.0.1:8001
|
|
|
+ FLUX_SUBMIT_TOOL_ID 默认 flux_submit
|
|
|
+ FLUX_QUERY_TOOL_ID 默认 flux_query
|
|
|
+ FLUX_TEST_MODEL 默认 flux-2-pro-preview
|
|
|
+ FLUX_TEST_PROMPT 覆盖默认短提示词
|
|
|
+ FLUX_POLL_INTERVAL_S 默认 1.0
|
|
|
+ FLUX_POLL_MAX_WAIT_S 默认 300
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import io
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+if sys.platform == "win32":
|
|
|
+ _out = sys.stdout
|
|
|
+ if isinstance(_out, io.TextIOWrapper):
|
|
|
+ _out.reconfigure(encoding="utf-8")
|
|
|
+
|
|
|
+import httpx
|
|
|
+
|
|
|
+ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
|
|
|
+SUBMIT_TOOL = os.environ.get("FLUX_SUBMIT_TOOL_ID", "flux_submit")
|
|
|
+QUERY_TOOL = os.environ.get("FLUX_QUERY_TOOL_ID", "flux_query")
|
|
|
+FLUX_MODEL = os.environ.get("FLUX_TEST_MODEL", "flux-2-pro-preview").strip()
|
|
|
+TEST_PROMPT = os.environ.get(
|
|
|
+ "FLUX_TEST_PROMPT",
|
|
|
+ "A tiny red apple on white background, simple product photo, minimal",
|
|
|
+)
|
|
|
+POLL_INTERVAL_S = float(os.environ.get("FLUX_POLL_INTERVAL_S", "1.0"))
|
|
|
+POLL_MAX_WAIT_S = float(os.environ.get("FLUX_POLL_MAX_WAIT_S", "300"))
|
|
|
+
|
|
|
+
|
|
|
+def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> dict[str, Any]:
|
|
|
+ resp = httpx.post(
|
|
|
+ f"{ROUTER_URL}/run_tool",
|
|
|
+ json={"tool_id": tool_id, "params": params},
|
|
|
+ timeout=timeout,
|
|
|
+ )
|
|
|
+ resp.raise_for_status()
|
|
|
+ body = resp.json()
|
|
|
+ if body.get("status") != "success":
|
|
|
+ raise RuntimeError(body.get("error") or str(body))
|
|
|
+ result = body.get("result")
|
|
|
+ if isinstance(result, dict) and result.get("status") == "error":
|
|
|
+ raise RuntimeError(result.get("error", str(result)))
|
|
|
+ return result if isinstance(result, dict) else {}
|
|
|
+
|
|
|
+
|
|
|
+def _poll_terminal_success(data: dict[str, Any]) -> bool:
|
|
|
+ s = str(data.get("status") or "").strip()
|
|
|
+ return s.lower() == "ready"
|
|
|
+
|
|
|
+
|
|
|
+def _poll_terminal_failure(data: dict[str, Any]) -> bool:
|
|
|
+ s = str(data.get("status") or "").strip().lower()
|
|
|
+ return s in ("error", "failed")
|
|
|
+
|
|
|
+
|
|
|
+def _sample_url(data: dict[str, Any]) -> str | None:
|
|
|
+ r = data.get("result")
|
|
|
+ if isinstance(r, dict):
|
|
|
+ u = r.get("sample")
|
|
|
+ if isinstance(u, str) and u.startswith("http"):
|
|
|
+ return u
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def main() -> None:
|
|
|
+ print("=" * 50)
|
|
|
+ print("测试 FLUX(BFL 异步 API + 模型可切换)")
|
|
|
+ print("=" * 50)
|
|
|
+ print(f"ROUTER_URL: {ROUTER_URL}")
|
|
|
+ print(f"model: {FLUX_MODEL}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
|
|
|
+ print(f"Router 状态: {r.json()}")
|
|
|
+ except httpx.ConnectError:
|
|
|
+ print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ print("\n--- 校验工具已注册 ---")
|
|
|
+ tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
|
|
|
+ tr.raise_for_status()
|
|
|
+ tools = tr.json().get("tools", [])
|
|
|
+ ids = {t["tool_id"] for t in tools}
|
|
|
+ for tid in (SUBMIT_TOOL, QUERY_TOOL):
|
|
|
+ if tid not in ids:
|
|
|
+ print(f"错误: {tid!r} 不在 GET /tools 中。示例 id: {sorted(ids)[:20]}...")
|
|
|
+ sys.exit(1)
|
|
|
+ meta = next(t for t in tools if t["tool_id"] == tid)
|
|
|
+ print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})")
|
|
|
+
|
|
|
+ props = (next(t for t in tools if t["tool_id"] == SUBMIT_TOOL).get("input_schema") or {}).get(
|
|
|
+ "properties"
|
|
|
+ ) or {}
|
|
|
+ if "model" in props:
|
|
|
+ print(" flux_submit input_schema 已声明 model")
|
|
|
+ else:
|
|
|
+ print(" 提示: flux_submit 宜在注册表中声明 model 以便切换端点")
|
|
|
+
|
|
|
+ print("\n--- flux_submit ---")
|
|
|
+ submit_params: dict[str, Any] = {
|
|
|
+ "model": FLUX_MODEL,
|
|
|
+ "prompt": TEST_PROMPT,
|
|
|
+ "width": 512,
|
|
|
+ "height": 512,
|
|
|
+ }
|
|
|
+ try:
|
|
|
+ sub = run_tool(SUBMIT_TOOL, submit_params, timeout=120.0)
|
|
|
+ except (RuntimeError, httpx.HTTPError) as e:
|
|
|
+ print(f"错误: {e}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ print(f"提交返回 keys: {list(sub.keys())}")
|
|
|
+ req_id = sub.get("id") or sub.get("request_id")
|
|
|
+ poll_url = sub.get("polling_url")
|
|
|
+ if not req_id or not poll_url:
|
|
|
+ print(f"错误: 缺少 id 或 polling_url: {sub}")
|
|
|
+ sys.exit(1)
|
|
|
+ print(f"request id: {req_id}")
|
|
|
+ print(f"polling_url: {poll_url[:80]}...")
|
|
|
+
|
|
|
+ print("\n--- flux_query 轮询 ---")
|
|
|
+ deadline = time.monotonic() + POLL_MAX_WAIT_S
|
|
|
+ last: dict[str, Any] = {}
|
|
|
+
|
|
|
+ while time.monotonic() < deadline:
|
|
|
+ time.sleep(POLL_INTERVAL_S)
|
|
|
+ try:
|
|
|
+ last = run_tool(
|
|
|
+ QUERY_TOOL,
|
|
|
+ {"polling_url": str(poll_url), "request_id": str(req_id)},
|
|
|
+ timeout=60.0,
|
|
|
+ )
|
|
|
+ except (RuntimeError, httpx.HTTPError) as e:
|
|
|
+ print(f"轮询错误: {e}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ st = last.get("status")
|
|
|
+ print(f" status: {st}")
|
|
|
+
|
|
|
+ if _poll_terminal_failure(last):
|
|
|
+ print(f"生成失败: {last}")
|
|
|
+ sys.exit(1)
|
|
|
+ if _poll_terminal_success(last):
|
|
|
+ url = _sample_url(last)
|
|
|
+ if url:
|
|
|
+ print(f"\n图片 URL(signed,约 10 分钟内有效): {url[:100]}...")
|
|
|
+ print("\n测试通过!")
|
|
|
+ return
|
|
|
+
|
|
|
+ print(f"\n等待超时 ({POLL_MAX_WAIT_S}s),最后一次: {last}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|