|
|
@@ -0,0 +1,231 @@
|
|
|
+"""测试 Midjourney 代理工具 — Router POST /run_tool
|
|
|
+
|
|
|
+本地服务将 JSON 原样转发至 MIDJOURNEY_API_BASE 上你已实现的三个接口:
|
|
|
+ POST /submit_job cookie, prompt, user_id, mode(relaxed|fast)
|
|
|
+ POST /query_job_status cookie, job_id
|
|
|
+ POST /get_image_urls job_id → 四张图链接
|
|
|
+
|
|
|
+用法:
|
|
|
+ 1. tools/local/midjourney/.env:MIDJOURNEY_API_BASE
|
|
|
+ 2. uv run python -m tool_agent
|
|
|
+ 3. uv run python tests/test_midjourney.py
|
|
|
+
|
|
|
+端到端(可选):设置 MIDJOURNEY_TEST_COOKIE、MIDJOURNEY_TEST_USER_ID 后脚本会
|
|
|
+ submit → 轮询 query → get_image_urls;否则仅校验工具已注册并退出 0。
|
|
|
+
|
|
|
+环境变量:
|
|
|
+ TOOL_AGENT_ROUTER_URL
|
|
|
+ MIDJOURNEY_SUBMIT_TOOL_ID 默认 midjourney_submit_job
|
|
|
+ MIDJOURNEY_QUERY_TOOL_ID 默认 midjourney_query_job_status
|
|
|
+ MIDJOURNEY_GET_URLS_TOOL_ID 默认 midjourney_get_image_urls
|
|
|
+ MIDJOURNEY_TEST_COOKIE / MIDJOURNEY_TEST_USER_ID / MIDJOURNEY_TEST_PROMPT / MIDJOURNEY_TEST_MODE
|
|
|
+ MIDJOURNEY_POLL_INTERVAL_S / MIDJOURNEY_POLL_MAX_WAIT_S
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import io
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+if sys.platform == "win32":
|
|
|
+ _out = sys.stdout
|
|
|
+ if isinstance(_out, io.TextIOWrapper):
|
|
|
+ _out.reconfigure(encoding="utf-8")
|
|
|
+
|
|
|
+import httpx
|
|
|
+
|
|
|
+ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
|
|
|
+T_SUBMIT = os.environ.get("MIDJOURNEY_SUBMIT_TOOL_ID", "midjourney_submit_job")
|
|
|
+T_QUERY = os.environ.get("MIDJOURNEY_QUERY_TOOL_ID", "midjourney_query_job_status")
|
|
|
+T_URLS = os.environ.get("MIDJOURNEY_GET_URLS_TOOL_ID", "midjourney_get_image_urls")
|
|
|
+TEST_COOKIE = os.environ.get("MIDJOURNEY_TEST_COOKIE", "").strip()
|
|
|
+TEST_USER_ID = os.environ.get("MIDJOURNEY_TEST_USER_ID", "").strip()
|
|
|
+TEST_PROMPT = os.environ.get("MIDJOURNEY_TEST_PROMPT", "a red apple on white background --v 6")
|
|
|
+TEST_MODE = os.environ.get("MIDJOURNEY_TEST_MODE", "fast").strip().lower()
|
|
|
+POLL_INTERVAL_S = float(os.environ.get("MIDJOURNEY_POLL_INTERVAL_S", "3"))
|
|
|
+POLL_MAX_WAIT_S = float(os.environ.get("MIDJOURNEY_POLL_MAX_WAIT_S", "600"))
|
|
|
+
|
|
|
+
|
|
|
+def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> Any:
|
|
|
+ resp = httpx.post(
|
|
|
+ f"{ROUTER_URL}/run_tool",
|
|
|
+ json={"tool_id": tool_id, "params": params},
|
|
|
+ timeout=timeout,
|
|
|
+ )
|
|
|
+ resp.raise_for_status()
|
|
|
+ body = resp.json()
|
|
|
+ if body.get("status") != "success":
|
|
|
+ raise RuntimeError(body.get("error") or str(body))
|
|
|
+ result = body.get("result")
|
|
|
+ if isinstance(result, dict) and result.get("status") == "error":
|
|
|
+ raise RuntimeError(result.get("error", str(result)))
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+def _extract_job_id(data: dict[str, Any]) -> str | None:
|
|
|
+ if not isinstance(data, dict):
|
|
|
+ return None
|
|
|
+ for key in ("job_id", "jobId", "id", "task_id", "taskId"):
|
|
|
+ v = data.get(key)
|
|
|
+ if v is not None and str(v).strip():
|
|
|
+ return str(v).strip()
|
|
|
+ inner = data.get("data")
|
|
|
+ if isinstance(inner, dict):
|
|
|
+ return _extract_job_id(inner)
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _status_terminal_ok(data: dict[str, Any]) -> bool:
|
|
|
+ if not isinstance(data, dict):
|
|
|
+ return False
|
|
|
+ s = str(
|
|
|
+ data.get("status")
|
|
|
+ or data.get("job_status")
|
|
|
+ or data.get("jobStatus")
|
|
|
+ or data.get("state")
|
|
|
+ or ""
|
|
|
+ ).lower()
|
|
|
+ if not s and isinstance(data.get("data"), dict):
|
|
|
+ return _status_terminal_ok(data["data"])
|
|
|
+ return any(k in s for k in ("complete", "success", "done", "finished", "succeed", "ready"))
|
|
|
+
|
|
|
+
|
|
|
+def _status_terminal_fail(data: dict[str, Any]) -> bool:
|
|
|
+ if not isinstance(data, dict):
|
|
|
+ return False
|
|
|
+ s = str(data.get("status") or data.get("job_status") or data.get("state") or "").lower()
|
|
|
+ return any(k in s for k in ("fail", "error", "cancel", "canceled", "cancelled"))
|
|
|
+
|
|
|
+
|
|
|
+def _extract_url_list(payload: Any) -> list[str]:
|
|
|
+ if isinstance(payload, list):
|
|
|
+ return [str(x) for x in payload if isinstance(x, str) and x.startswith("http")]
|
|
|
+ if not isinstance(payload, dict):
|
|
|
+ return []
|
|
|
+ for key in ("image_urls", "urls", "images", "data"):
|
|
|
+ v = payload.get(key)
|
|
|
+ if isinstance(v, list):
|
|
|
+ out = [str(x) for x in v if isinstance(x, str) and x.startswith("http")]
|
|
|
+ if out:
|
|
|
+ return out
|
|
|
+ if isinstance(v, dict):
|
|
|
+ nested = _extract_url_list(v)
|
|
|
+ if nested:
|
|
|
+ return nested
|
|
|
+ return _extract_url_list(payload.get("data"))
|
|
|
+
|
|
|
+
|
|
|
+def main() -> None:
|
|
|
+ print("=" * 50)
|
|
|
+ print("测试 Midjourney(submit / query / get_image_urls)")
|
|
|
+ print("=" * 50)
|
|
|
+ print(f"ROUTER_URL: {ROUTER_URL}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
|
|
|
+ print(f"Router 状态: {r.json()}")
|
|
|
+ except httpx.ConnectError:
|
|
|
+ print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ print("\n--- 校验工具已注册 ---")
|
|
|
+ tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
|
|
|
+ tr.raise_for_status()
|
|
|
+ tools = tr.json().get("tools", [])
|
|
|
+ ids = {t["tool_id"] for t in tools}
|
|
|
+ for tid in (T_SUBMIT, T_QUERY, T_URLS):
|
|
|
+ if tid not in ids:
|
|
|
+ print(f"错误: {tid!r} 不在 GET /tools 中。示例: {sorted(ids)[:25]}...")
|
|
|
+ sys.exit(1)
|
|
|
+ meta = next(t for t in tools if t["tool_id"] == tid)
|
|
|
+ print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})")
|
|
|
+
|
|
|
+ if not TEST_COOKIE or not TEST_USER_ID:
|
|
|
+ print(
|
|
|
+ "\n未设置 MIDJOURNEY_TEST_COOKIE 与 MIDJOURNEY_TEST_USER_ID,跳过端到端;"
|
|
|
+ "工具注册检查已通过,退出 0。"
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ if TEST_MODE not in ("relaxed", "fast"):
|
|
|
+ print(f"错误: MIDJOURNEY_TEST_MODE 须为 relaxed 或 fast,当前: {TEST_MODE!r}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ print("\n--- midjourney_submit_job ---")
|
|
|
+ try:
|
|
|
+ sub = run_tool(
|
|
|
+ T_SUBMIT,
|
|
|
+ {
|
|
|
+ "cookie": TEST_COOKIE,
|
|
|
+ "prompt": TEST_PROMPT,
|
|
|
+ "user_id": TEST_USER_ID,
|
|
|
+ "mode": TEST_MODE,
|
|
|
+ },
|
|
|
+ timeout=180.0,
|
|
|
+ )
|
|
|
+ except (RuntimeError, httpx.HTTPError) as e:
|
|
|
+ print(f"错误: {e}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ if not isinstance(sub, dict):
|
|
|
+ print(f"错误: submit 返回非 object: {type(sub)}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ job_id = _extract_job_id(sub)
|
|
|
+ if not job_id:
|
|
|
+ print(f"错误: 无法从 submit 响应解析 job_id: {sub}")
|
|
|
+ sys.exit(1)
|
|
|
+ print(f"job_id: {job_id}")
|
|
|
+
|
|
|
+ print("\n--- midjourney_query_job_status 轮询 ---")
|
|
|
+ deadline = time.monotonic() + POLL_MAX_WAIT_S
|
|
|
+ last: dict[str, Any] = {}
|
|
|
+
|
|
|
+ while time.monotonic() < deadline:
|
|
|
+ time.sleep(POLL_INTERVAL_S)
|
|
|
+ try:
|
|
|
+ q = run_tool(
|
|
|
+ T_QUERY,
|
|
|
+ {"cookie": TEST_COOKIE, "job_id": job_id},
|
|
|
+ timeout=120.0,
|
|
|
+ )
|
|
|
+ except (RuntimeError, httpx.HTTPError) as e:
|
|
|
+ print(f"轮询错误: {e}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ last = q if isinstance(q, dict) else {}
|
|
|
+ st = last.get("status") or last.get("job_status") or last.get("state")
|
|
|
+ print(f" status: {st}")
|
|
|
+
|
|
|
+ if _status_terminal_fail(last):
|
|
|
+ print(f"任务失败: {last}")
|
|
|
+ sys.exit(1)
|
|
|
+ if _status_terminal_ok(last):
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ print(f"等待超时 ({POLL_MAX_WAIT_S}s),最后响应: {last}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ print("\n--- midjourney_get_image_urls ---")
|
|
|
+ try:
|
|
|
+ urls_payload = run_tool(T_URLS, {"job_id": job_id}, timeout=120.0)
|
|
|
+ except (RuntimeError, httpx.HTTPError) as e:
|
|
|
+ print(f"错误: {e}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ urls = _extract_url_list(urls_payload)
|
|
|
+ if len(urls) < 4:
|
|
|
+ print(f"警告: 期望至少 4 个 http 链接,实际 {len(urls)};原始: {str(urls_payload)[:500]}")
|
|
|
+ if len(urls) == 0:
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ for i, u in enumerate(urls[:4], 1):
|
|
|
+ print(f" [{i}] {u[:96]}...")
|
|
|
+ print("\n测试通过!")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|