| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- """测试 Midjourney 代理工具 — Router POST /run_tool
- 本地服务将 JSON 原样转发至 MIDJOURNEY_API_BASE 上你已实现的三个接口:
- POST /submit_job cookie, prompt, user_id, mode(relaxed|fast)
- POST /query_job_status cookie, job_id
- POST /get_image_urls job_id → 四张图链接
- 用法:
- 1. tools/local/midjourney/.env:MIDJOURNEY_API_BASE
- 2. uv run python -m tool_agent
- 3. uv run python tests/test_midjourney.py
- 端到端(可选):设置 MIDJOURNEY_TEST_COOKIE、MIDJOURNEY_TEST_USER_ID 后脚本会
- submit → 轮询 query → get_image_urls;否则仅校验工具已注册并退出 0。
- 环境变量:
- TOOL_AGENT_ROUTER_URL
- MIDJOURNEY_SUBMIT_TOOL_ID 默认 midjourney_submit_job
- MIDJOURNEY_QUERY_TOOL_ID 默认 midjourney_query_job_status
- MIDJOURNEY_GET_URLS_TOOL_ID 默认 midjourney_get_image_urls
- MIDJOURNEY_TEST_COOKIE / MIDJOURNEY_TEST_USER_ID / MIDJOURNEY_TEST_PROMPT / MIDJOURNEY_TEST_MODE
- MIDJOURNEY_POLL_INTERVAL_S / MIDJOURNEY_POLL_MAX_WAIT_S
- """
- from __future__ import annotations
- import io
- import os
- import sys
- import time
- from typing import Any
# On Windows, switch stdout to UTF-8 when it is a regular text wrapper --
# presumably so the non-ASCII messages printed by this script are not mangled
# by the console's default encoding (TODO confirm the original motivation).
if sys.platform == "win32" and isinstance(sys.stdout, io.TextIOWrapper):
    sys.stdout.reconfigure(encoding="utf-8")
- import httpx
# --- Configuration, each value overridable through an environment variable ---

# Base URL of the Router.  Surrounding whitespace and trailing slashes are
# removed because every request below builds its URL as
# f"{ROUTER_URL}/<endpoint>"; a base ending in "/" would otherwise produce
# paths such as "//run_tool".
ROUTER_URL = (
    os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001").strip().rstrip("/")
)

# Tool identifiers passed as "tool_id" to POST /run_tool.
T_SUBMIT = os.environ.get("MIDJOURNEY_SUBMIT_TOOL_ID", "midjourney_submit_job")
T_QUERY = os.environ.get("MIDJOURNEY_QUERY_TOOL_ID", "midjourney_query_job_status")
T_URLS = os.environ.get("MIDJOURNEY_GET_URLS_TOOL_ID", "midjourney_get_image_urls")

# Inputs for the end-to-end run; the run is skipped when the cookie or the
# user id is empty.
TEST_COOKIE = os.environ.get("MIDJOURNEY_TEST_COOKIE", "").strip()
TEST_USER_ID = os.environ.get("MIDJOURNEY_TEST_USER_ID", "").strip()
TEST_PROMPT = os.environ.get("MIDJOURNEY_TEST_PROMPT", "a red apple on white background --v 6")
TEST_MODE = os.environ.get("MIDJOURNEY_TEST_MODE", "fast").strip().lower()

# Polling cadence and overall deadline, in seconds.  A non-numeric value
# raises ValueError at import time.
POLL_INTERVAL_S = float(os.environ.get("MIDJOURNEY_POLL_INTERVAL_S", "3"))
POLL_MAX_WAIT_S = float(os.environ.get("MIDJOURNEY_POLL_MAX_WAIT_S", "600"))
def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> Any:
    """Run one tool through the Router's ``POST /run_tool`` endpoint.

    Args:
        tool_id: Identifier sent as the ``tool_id`` field of the request.
        params: Tool parameters sent as the ``params`` field of the request.
        timeout: Timeout in seconds handed to ``httpx.post``.

    Returns:
        The ``result`` field of the Router's JSON response.

    Raises:
        httpx.HTTPError: On a transport failure, or when
            ``raise_for_status()`` rejects the response status.
        RuntimeError: If the body is not a JSON object, if its ``status`` is
            not ``"success"``, or if ``result`` is an object whose ``status``
            is ``"error"``.
    """
    resp = httpx.post(
        f"{ROUTER_URL}/run_tool",
        json={"tool_id": tool_id, "params": params},
        timeout=timeout,
    )
    resp.raise_for_status()

    # The callers in this file only handle RuntimeError and httpx.HTTPError,
    # so an undecodable or oddly shaped body is reported as RuntimeError
    # instead of escaping as ValueError / AttributeError.
    try:
        body = resp.json()
    except ValueError as exc:
        raise RuntimeError(f"Router 返回非 JSON 响应: {resp.text[:200]!r}") from exc
    if not isinstance(body, dict):
        raise RuntimeError(f"Router 返回非 object: {body!r}")

    if body.get("status") != "success":
        raise RuntimeError(body.get("error") or str(body))

    result = body.get("result")
    # A tool can report its own failure inside an otherwise successful envelope.
    if isinstance(result, dict) and result.get("status") == "error":
        # Fall back to the whole result when "error" is missing or empty.
        raise RuntimeError(result.get("error") or str(result))
    return result
- def _extract_job_id(data: dict[str, Any]) -> str | None:
- if not isinstance(data, dict):
- return None
- for key in ("job_id", "jobId", "id", "task_id", "taskId"):
- v = data.get(key)
- if v is not None and str(v).strip():
- return str(v).strip()
- inner = data.get("data")
- if isinstance(inner, dict):
- return _extract_job_id(inner)
- return None
- def _status_terminal_ok(data: dict[str, Any]) -> bool:
- if not isinstance(data, dict):
- return False
- s = str(
- data.get("status")
- or data.get("job_status")
- or data.get("jobStatus")
- or data.get("state")
- or ""
- ).lower()
- if not s and isinstance(data.get("data"), dict):
- return _status_terminal_ok(data["data"])
- return any(k in s for k in ("complete", "success", "done", "finished", "succeed", "ready"))
- def _status_terminal_fail(data: dict[str, Any]) -> bool:
- if not isinstance(data, dict):
- return False
- s = str(data.get("status") or data.get("job_status") or data.get("state") or "").lower()
- return any(k in s for k in ("fail", "error", "cancel", "canceled", "cancelled"))
- def _extract_url_list(payload: Any) -> list[str]:
- if isinstance(payload, list):
- return [str(x) for x in payload if isinstance(x, str) and x.startswith("http")]
- if not isinstance(payload, dict):
- return []
- for key in ("image_urls", "urls", "images", "data"):
- v = payload.get(key)
- if isinstance(v, list):
- out = [str(x) for x in v if isinstance(x, str) and x.startswith("http")]
- if out:
- return out
- if isinstance(v, dict):
- nested = _extract_url_list(v)
- if nested:
- return nested
- return _extract_url_list(payload.get("data"))
def _check_router_reachable() -> None:
    """Print the Router's ``/health`` payload, or exit with status 1."""
    try:
        r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
        print(f"Router 状态: {r.json()}")
    except httpx.ConnectError:
        print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
        sys.exit(1)
    except (httpx.HTTPError, ValueError) as e:
        # Timeouts and other transport errors, or a health body that is not
        # JSON, end the run with a message rather than a traceback.
        print(f"错误: {e}")
        sys.exit(1)


def _check_tools_registered() -> None:
    """Verify the three Midjourney tools appear in ``GET /tools``, else exit 1."""
    try:
        tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
        tr.raise_for_status()
        listing = tr.json()
    except (httpx.HTTPError, ValueError) as e:
        print(f"错误: {e}")
        sys.exit(1)

    tools = (listing.get("tools") or []) if isinstance(listing, dict) else []
    # Index by tool_id, keeping the first entry for each id and skipping
    # entries that carry no "tool_id" instead of raising KeyError on them.
    registry: dict[Any, Any] = {}
    for entry in tools:
        if isinstance(entry, dict) and "tool_id" in entry:
            registry.setdefault(entry["tool_id"], entry)

    for tid in (T_SUBMIT, T_QUERY, T_URLS):
        if tid not in registry:
            print(f"错误: {tid!r} 不在 GET /tools 中。示例: {sorted(registry)[:25]}...")
            sys.exit(1)
        meta = registry[tid]
        print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})")


def _submit_job() -> str:
    """Submit the test prompt and return the job id; exit 1 on any failure."""
    print("\n--- midjourney_submit_job ---")
    try:
        sub = run_tool(
            T_SUBMIT,
            {
                "cookie": TEST_COOKIE,
                "prompt": TEST_PROMPT,
                "user_id": TEST_USER_ID,
                "mode": TEST_MODE,
            },
            timeout=180.0,
        )
    except (RuntimeError, httpx.HTTPError) as e:
        print(f"错误: {e}")
        sys.exit(1)
    if not isinstance(sub, dict):
        print(f"错误: submit 返回非 object: {type(sub)}")
        sys.exit(1)
    job_id = _extract_job_id(sub)
    if not job_id:
        print(f"错误: 无法从 submit 响应解析 job_id: {sub}")
        sys.exit(1)
    print(f"job_id: {job_id}")
    return job_id


def _wait_for_completion(job_id: str) -> None:
    """Poll the job status until it succeeds; exit 1 on failure or timeout."""
    print("\n--- midjourney_query_job_status 轮询 ---")
    deadline = time.monotonic() + POLL_MAX_WAIT_S
    last: dict[str, Any] = {}
    while time.monotonic() < deadline:
        # Sleep first: the job was submitted just before this loop started.
        time.sleep(POLL_INTERVAL_S)
        try:
            q = run_tool(
                T_QUERY,
                {"cookie": TEST_COOKIE, "job_id": job_id},
                timeout=120.0,
            )
        except (RuntimeError, httpx.HTTPError) as e:
            print(f"轮询错误: {e}")
            sys.exit(1)
        last = q if isinstance(q, dict) else {}
        # Display only; "jobStatus" is included because _status_terminal_ok
        # reads that key as well.
        st = (
            last.get("status")
            or last.get("job_status")
            or last.get("jobStatus")
            or last.get("state")
        )
        print(f" status: {st}")
        # Failure is checked before success.
        if _status_terminal_fail(last):
            print(f"任务失败: {last}")
            sys.exit(1)
        if _status_terminal_ok(last):
            break
    else:
        # The loop ran out of time without reaching a terminal status.
        print(f"等待超时 ({POLL_MAX_WAIT_S}s),最后响应: {last}")
        sys.exit(1)


def _report_image_urls(job_id: str) -> None:
    """Fetch and print the job's image URLs; exit 1 if none are returned."""
    print("\n--- midjourney_get_image_urls ---")
    try:
        urls_payload = run_tool(T_URLS, {"job_id": job_id}, timeout=120.0)
    except (RuntimeError, httpx.HTTPError) as e:
        print(f"错误: {e}")
        sys.exit(1)
    urls = _extract_url_list(urls_payload)
    if len(urls) < 4:
        # Fewer than four links is only a warning; zero links is a failure.
        print(f"警告: 期望至少 4 个 http 链接,实际 {len(urls)};原始: {str(urls_payload)[:500]}")
        if not urls:
            sys.exit(1)
    for i, u in enumerate(urls[:4], 1):
        print(f" [{i}] {u[:96]}...")


def main() -> None:
    """Run the Midjourney tool checks against the Router.

    Steps: check ``/health``, verify the three tools are listed by
    ``GET /tools``, then -- only when both MIDJOURNEY_TEST_COOKIE and
    MIDJOURNEY_TEST_USER_ID are set -- submit a job, poll its status and
    fetch its image URLs.  Every failure prints a message and exits with
    status 1; returning normally means the checks that ran have passed.
    """
    print("=" * 50)
    print("测试 Midjourney(submit / query / get_image_urls)")
    print("=" * 50)
    print(f"ROUTER_URL: {ROUTER_URL}")
    _check_router_reachable()

    print("\n--- 校验工具已注册 ---")
    _check_tools_registered()

    if not TEST_COOKIE or not TEST_USER_ID:
        print(
            "\n未设置 MIDJOURNEY_TEST_COOKIE 与 MIDJOURNEY_TEST_USER_ID,跳过端到端;"
            "工具注册检查已通过,退出 0。"
        )
        return
    if TEST_MODE not in ("relaxed", "fast"):
        print(f"错误: MIDJOURNEY_TEST_MODE 须为 relaxed 或 fast,当前: {TEST_MODE!r}")
        sys.exit(1)

    job_id = _submit_job()
    _wait_for_completion(job_id)
    _report_image_urls(job_id)
    print("\n测试通过!")


if __name__ == "__main__":
    main()
|