"""测试 Midjourney 代理工具 — Router POST /run_tool 本地服务将 JSON 原样转发至 MIDJOURNEY_API_BASE 上你已实现的三个接口: POST /submit_job cookie, prompt, user_id, mode(relaxed|fast) POST /query_job_status cookie, job_id POST /get_image_urls job_id → 四张图链接 用法: 1. tools/local/midjourney/.env:MIDJOURNEY_API_BASE 2. uv run python -m tool_agent 3. uv run python tests/test_midjourney.py 端到端(可选):设置 MIDJOURNEY_TEST_COOKIE、MIDJOURNEY_TEST_USER_ID 后脚本会 submit → 轮询 query → get_image_urls;否则仅校验工具已注册并退出 0。 环境变量: TOOL_AGENT_ROUTER_URL MIDJOURNEY_SUBMIT_TOOL_ID 默认 midjourney_submit_job MIDJOURNEY_QUERY_TOOL_ID 默认 midjourney_query_job_status MIDJOURNEY_GET_URLS_TOOL_ID 默认 midjourney_get_image_urls MIDJOURNEY_TEST_COOKIE / MIDJOURNEY_TEST_USER_ID / MIDJOURNEY_TEST_PROMPT / MIDJOURNEY_TEST_MODE MIDJOURNEY_POLL_INTERVAL_S / MIDJOURNEY_POLL_MAX_WAIT_S """ from __future__ import annotations import io import os import sys import time from typing import Any if sys.platform == "win32": _out = sys.stdout if isinstance(_out, io.TextIOWrapper): _out.reconfigure(encoding="utf-8") import httpx ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001") T_SUBMIT = os.environ.get("MIDJOURNEY_SUBMIT_TOOL_ID", "midjourney_submit_job") T_QUERY = os.environ.get("MIDJOURNEY_QUERY_TOOL_ID", "midjourney_query_job_status") T_URLS = os.environ.get("MIDJOURNEY_GET_URLS_TOOL_ID", "midjourney_get_image_urls") TEST_COOKIE = os.environ.get("MIDJOURNEY_TEST_COOKIE", "").strip() TEST_USER_ID = os.environ.get("MIDJOURNEY_TEST_USER_ID", "").strip() TEST_PROMPT = os.environ.get("MIDJOURNEY_TEST_PROMPT", "a red apple on white background --v 6") TEST_MODE = os.environ.get("MIDJOURNEY_TEST_MODE", "fast").strip().lower() POLL_INTERVAL_S = float(os.environ.get("MIDJOURNEY_POLL_INTERVAL_S", "3")) POLL_MAX_WAIT_S = float(os.environ.get("MIDJOURNEY_POLL_MAX_WAIT_S", "600")) def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> Any: resp = httpx.post( f"{ROUTER_URL}/run_tool", json={"tool_id": tool_id, "params": params}, timeout=timeout, ) resp.raise_for_status() body = resp.json() if body.get("status") != "success": raise RuntimeError(body.get("error") or str(body)) result = body.get("result") if isinstance(result, dict) and result.get("status") == "error": raise RuntimeError(result.get("error", str(result))) return result def _extract_job_id(data: dict[str, Any]) -> str | None: if not isinstance(data, dict): return None for key in ("job_id", "jobId", "id", "task_id", "taskId"): v = data.get(key) if v is not None and str(v).strip(): return str(v).strip() inner = data.get("data") if isinstance(inner, dict): return _extract_job_id(inner) return None def _status_terminal_ok(data: dict[str, Any]) -> bool: if not isinstance(data, dict): return False s = str( data.get("status") or data.get("job_status") or data.get("jobStatus") or data.get("state") or "" ).lower() if not s and isinstance(data.get("data"), dict): return _status_terminal_ok(data["data"]) return any(k in s for k in ("complete", "success", "done", "finished", "succeed", "ready")) def _status_terminal_fail(data: dict[str, Any]) -> bool: if not isinstance(data, dict): return False s = str(data.get("status") or data.get("job_status") or data.get("state") or "").lower() return any(k in s for k in ("fail", "error", "cancel", "canceled", "cancelled")) def _extract_url_list(payload: Any) -> list[str]: if isinstance(payload, list): return [str(x) for x in payload if isinstance(x, str) and x.startswith("http")] if not isinstance(payload, dict): return [] for key in ("image_urls", "urls", "images", "data"): v = payload.get(key) if isinstance(v, list): out = [str(x) for x in v if isinstance(x, str) and x.startswith("http")] if out: return out if isinstance(v, dict): nested = _extract_url_list(v) if nested: return nested return _extract_url_list(payload.get("data")) def main() -> None: print("=" * 50) print("测试 Midjourney(submit / query / get_image_urls)") print("=" * 50) print(f"ROUTER_URL: {ROUTER_URL}") try: r = httpx.get(f"{ROUTER_URL}/health", timeout=3) print(f"Router 状态: {r.json()}") except httpx.ConnectError: print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent") sys.exit(1) print("\n--- 校验工具已注册 ---") tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30) tr.raise_for_status() tools = tr.json().get("tools", []) ids = {t["tool_id"] for t in tools} for tid in (T_SUBMIT, T_QUERY, T_URLS): if tid not in ids: print(f"错误: {tid!r} 不在 GET /tools 中。示例: {sorted(ids)[:25]}...") sys.exit(1) meta = next(t for t in tools if t["tool_id"] == tid) print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})") if not TEST_COOKIE or not TEST_USER_ID: print( "\n未设置 MIDJOURNEY_TEST_COOKIE 与 MIDJOURNEY_TEST_USER_ID,跳过端到端;" "工具注册检查已通过,退出 0。" ) return if TEST_MODE not in ("relaxed", "fast"): print(f"错误: MIDJOURNEY_TEST_MODE 须为 relaxed 或 fast,当前: {TEST_MODE!r}") sys.exit(1) print("\n--- midjourney_submit_job ---") try: sub = run_tool( T_SUBMIT, { "cookie": TEST_COOKIE, "prompt": TEST_PROMPT, "user_id": TEST_USER_ID, "mode": TEST_MODE, }, timeout=180.0, ) except (RuntimeError, httpx.HTTPError) as e: print(f"错误: {e}") sys.exit(1) if not isinstance(sub, dict): print(f"错误: submit 返回非 object: {type(sub)}") sys.exit(1) job_id = _extract_job_id(sub) if not job_id: print(f"错误: 无法从 submit 响应解析 job_id: {sub}") sys.exit(1) print(f"job_id: {job_id}") print("\n--- midjourney_query_job_status 轮询 ---") deadline = time.monotonic() + POLL_MAX_WAIT_S last: dict[str, Any] = {} while time.monotonic() < deadline: time.sleep(POLL_INTERVAL_S) try: q = run_tool( T_QUERY, {"cookie": TEST_COOKIE, "job_id": job_id}, timeout=120.0, ) except (RuntimeError, httpx.HTTPError) as e: print(f"轮询错误: {e}") sys.exit(1) last = q if isinstance(q, dict) else {} st = last.get("status") or last.get("job_status") or last.get("state") print(f" status: {st}") if _status_terminal_fail(last): print(f"任务失败: {last}") sys.exit(1) if _status_terminal_ok(last): break else: print(f"等待超时 ({POLL_MAX_WAIT_S}s),最后响应: {last}") sys.exit(1) print("\n--- midjourney_get_image_urls ---") try: urls_payload = run_tool(T_URLS, {"job_id": job_id}, timeout=120.0) except (RuntimeError, httpx.HTTPError) as e: print(f"错误: {e}") sys.exit(1) urls = _extract_url_list(urls_payload) if len(urls) < 4: print(f"警告: 期望至少 4 个 http 链接,实际 {len(urls)};原始: {str(urls_payload)[:500]}") if len(urls) == 0: sys.exit(1) for i, u in enumerate(urls[:4], 1): print(f" [{i}] {u[:96]}...") print("\n测试通过!") if __name__ == "__main__": main()