test_midjourney.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
"""测试 Midjourney 代理工具 — 通过 Router 的 POST /run_tool 调用。

本地服务将 JSON 原样转发至 MIDJOURNEY_API_BASE 上你已实现的三个接口:
    POST /submit_job         cookie, prompt, user_id, mode(relaxed|fast)
    POST /query_job_status   cookie, job_id
    POST /get_image_urls     job_id → 四张图链接

用法:
    1. 在 tools/local/midjourney/.env 中设置 MIDJOURNEY_API_BASE
    2. uv run python -m tool_agent
    3. uv run python tests/test_midjourney.py

端到端(可选):设置 MIDJOURNEY_TEST_COOKIE、MIDJOURNEY_TEST_USER_ID 后脚本会
submit → 轮询 query → get_image_urls;否则仅校验工具已注册并退出 0。

环境变量:
    TOOL_AGENT_ROUTER_URL          默认 http://127.0.0.1:8001
    MIDJOURNEY_SUBMIT_TOOL_ID      默认 midjourney_submit_job
    MIDJOURNEY_QUERY_TOOL_ID       默认 midjourney_query_job_status
    MIDJOURNEY_GET_URLS_TOOL_ID    默认 midjourney_get_image_urls
    MIDJOURNEY_TEST_COOKIE / MIDJOURNEY_TEST_USER_ID / MIDJOURNEY_TEST_PROMPT / MIDJOURNEY_TEST_MODE
        (MIDJOURNEY_TEST_MODE 默认 fast,须为 relaxed 或 fast)
    MIDJOURNEY_POLL_INTERVAL_S / MIDJOURNEY_POLL_MAX_WAIT_S   默认 3 / 600 秒
"""
  20. from __future__ import annotations
  21. import io
  22. import os
  23. import sys
  24. import time
  25. from typing import Any
  26. if sys.platform == "win32":
  27. _out = sys.stdout
  28. if isinstance(_out, io.TextIOWrapper):
  29. _out.reconfigure(encoding="utf-8")
  30. import httpx
  31. ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
  32. T_SUBMIT = os.environ.get("MIDJOURNEY_SUBMIT_TOOL_ID", "midjourney_submit_job")
  33. T_QUERY = os.environ.get("MIDJOURNEY_QUERY_TOOL_ID", "midjourney_query_job_status")
  34. T_URLS = os.environ.get("MIDJOURNEY_GET_URLS_TOOL_ID", "midjourney_get_image_urls")
  35. TEST_COOKIE = os.environ.get("MIDJOURNEY_TEST_COOKIE", "").strip()
  36. TEST_USER_ID = os.environ.get("MIDJOURNEY_TEST_USER_ID", "").strip()
  37. TEST_PROMPT = os.environ.get("MIDJOURNEY_TEST_PROMPT", "a red apple on white background --v 6")
  38. TEST_MODE = os.environ.get("MIDJOURNEY_TEST_MODE", "fast").strip().lower()
  39. POLL_INTERVAL_S = float(os.environ.get("MIDJOURNEY_POLL_INTERVAL_S", "3"))
  40. POLL_MAX_WAIT_S = float(os.environ.get("MIDJOURNEY_POLL_MAX_WAIT_S", "600"))
  41. def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> Any:
  42. resp = httpx.post(
  43. f"{ROUTER_URL}/run_tool",
  44. json={"tool_id": tool_id, "params": params},
  45. timeout=timeout,
  46. )
  47. resp.raise_for_status()
  48. body = resp.json()
  49. if body.get("status") != "success":
  50. raise RuntimeError(body.get("error") or str(body))
  51. result = body.get("result")
  52. if isinstance(result, dict) and result.get("status") == "error":
  53. raise RuntimeError(result.get("error", str(result)))
  54. return result
  55. def _extract_job_id(data: dict[str, Any]) -> str | None:
  56. if not isinstance(data, dict):
  57. return None
  58. for key in ("job_id", "jobId", "id", "task_id", "taskId"):
  59. v = data.get(key)
  60. if v is not None and str(v).strip():
  61. return str(v).strip()
  62. inner = data.get("data")
  63. if isinstance(inner, dict):
  64. return _extract_job_id(inner)
  65. return None
  66. def _status_terminal_ok(data: dict[str, Any]) -> bool:
  67. if not isinstance(data, dict):
  68. return False
  69. s = str(
  70. data.get("status")
  71. or data.get("job_status")
  72. or data.get("jobStatus")
  73. or data.get("state")
  74. or ""
  75. ).lower()
  76. if not s and isinstance(data.get("data"), dict):
  77. return _status_terminal_ok(data["data"])
  78. return any(k in s for k in ("complete", "success", "done", "finished", "succeed", "ready"))
  79. def _status_terminal_fail(data: dict[str, Any]) -> bool:
  80. if not isinstance(data, dict):
  81. return False
  82. s = str(data.get("status") or data.get("job_status") or data.get("state") or "").lower()
  83. return any(k in s for k in ("fail", "error", "cancel", "canceled", "cancelled"))
  84. def _extract_url_list(payload: Any) -> list[str]:
  85. if isinstance(payload, list):
  86. return [str(x) for x in payload if isinstance(x, str) and x.startswith("http")]
  87. if not isinstance(payload, dict):
  88. return []
  89. for key in ("image_urls", "urls", "images", "data"):
  90. v = payload.get(key)
  91. if isinstance(v, list):
  92. out = [str(x) for x in v if isinstance(x, str) and x.startswith("http")]
  93. if out:
  94. return out
  95. if isinstance(v, dict):
  96. nested = _extract_url_list(v)
  97. if nested:
  98. return nested
  99. return _extract_url_list(payload.get("data"))
  100. def main() -> None:
  101. print("=" * 50)
  102. print("测试 Midjourney(submit / query / get_image_urls)")
  103. print("=" * 50)
  104. print(f"ROUTER_URL: {ROUTER_URL}")
  105. try:
  106. r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
  107. print(f"Router 状态: {r.json()}")
  108. except httpx.ConnectError:
  109. print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
  110. sys.exit(1)
  111. print("\n--- 校验工具已注册 ---")
  112. tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
  113. tr.raise_for_status()
  114. tools = tr.json().get("tools", [])
  115. ids = {t["tool_id"] for t in tools}
  116. for tid in (T_SUBMIT, T_QUERY, T_URLS):
  117. if tid not in ids:
  118. print(f"错误: {tid!r} 不在 GET /tools 中。示例: {sorted(ids)[:25]}...")
  119. sys.exit(1)
  120. meta = next(t for t in tools if t["tool_id"] == tid)
  121. print(f" {tid}: {meta.get('name', '')} (state={meta.get('state')})")
  122. if not TEST_COOKIE or not TEST_USER_ID:
  123. print(
  124. "\n未设置 MIDJOURNEY_TEST_COOKIE 与 MIDJOURNEY_TEST_USER_ID,跳过端到端;"
  125. "工具注册检查已通过,退出 0。"
  126. )
  127. return
  128. if TEST_MODE not in ("relaxed", "fast"):
  129. print(f"错误: MIDJOURNEY_TEST_MODE 须为 relaxed 或 fast,当前: {TEST_MODE!r}")
  130. sys.exit(1)
  131. print("\n--- midjourney_submit_job ---")
  132. try:
  133. sub = run_tool(
  134. T_SUBMIT,
  135. {
  136. "cookie": TEST_COOKIE,
  137. "prompt": TEST_PROMPT,
  138. "user_id": TEST_USER_ID,
  139. "mode": TEST_MODE,
  140. },
  141. timeout=180.0,
  142. )
  143. except (RuntimeError, httpx.HTTPError) as e:
  144. print(f"错误: {e}")
  145. sys.exit(1)
  146. if not isinstance(sub, dict):
  147. print(f"错误: submit 返回非 object: {type(sub)}")
  148. sys.exit(1)
  149. job_id = _extract_job_id(sub)
  150. if not job_id:
  151. print(f"错误: 无法从 submit 响应解析 job_id: {sub}")
  152. sys.exit(1)
  153. print(f"job_id: {job_id}")
  154. print("\n--- midjourney_query_job_status 轮询 ---")
  155. deadline = time.monotonic() + POLL_MAX_WAIT_S
  156. last: dict[str, Any] = {}
  157. while time.monotonic() < deadline:
  158. time.sleep(POLL_INTERVAL_S)
  159. try:
  160. q = run_tool(
  161. T_QUERY,
  162. {"cookie": TEST_COOKIE, "job_id": job_id},
  163. timeout=120.0,
  164. )
  165. except (RuntimeError, httpx.HTTPError) as e:
  166. print(f"轮询错误: {e}")
  167. sys.exit(1)
  168. last = q if isinstance(q, dict) else {}
  169. st = last.get("status") or last.get("job_status") or last.get("state")
  170. print(f" status: {st}")
  171. if _status_terminal_fail(last):
  172. print(f"任务失败: {last}")
  173. sys.exit(1)
  174. if _status_terminal_ok(last):
  175. break
  176. else:
  177. print(f"等待超时 ({POLL_MAX_WAIT_S}s),最后响应: {last}")
  178. sys.exit(1)
  179. print("\n--- midjourney_get_image_urls ---")
  180. try:
  181. urls_payload = run_tool(T_URLS, {"job_id": job_id}, timeout=120.0)
  182. except (RuntimeError, httpx.HTTPError) as e:
  183. print(f"错误: {e}")
  184. sys.exit(1)
  185. urls = _extract_url_list(urls_payload)
  186. if len(urls) < 4:
  187. print(f"警告: 期望至少 4 个 http 链接,实际 {len(urls)};原始: {str(urls_payload)[:500]}")
  188. if len(urls) == 0:
  189. sys.exit(1)
  190. for i, u in enumerate(urls[:4], 1):
  191. print(f" [{i}] {u[:96]}...")
  192. print("\n测试通过!")
  193. if __name__ == "__main__":
  194. main()