kevin.yang 4 дней назад
Родитель
Сommit
68d1a203b2

+ 17 - 0
data/groups.json

@@ -46,6 +46,23 @@
         "flux_query"
       ],
       "usage_example": "1. flux_submit 传入 model(如 flux-2-pro-preview)、prompt,得到 id 与 polling_url\n2. 反复调用 flux_query 传入 polling_url、request_id(即 id),直到 status 为 Ready,从 result.sample 取图(签名 URL 约 10 分钟有效)"
+    },
+    {
+      "group_id": "midjourney_lifecycle",
+      "name": "Midjourney(提交、查状态、取图)",
+      "description": "先提交任务,轮询状态完成后获取四张图链接",
+      "category": "remote",
+      "tool_ids": [
+        "midjourney_submit_job",
+        "midjourney_query_job_status",
+        "midjourney_get_image_urls"
+      ],
+      "usage_order": [
+        "midjourney_submit_job",
+        "midjourney_query_job_status",
+        "midjourney_get_image_urls"
+      ],
+      "usage_example": "1. midjourney_submit_job:cookie、prompt、user_id、mode(relaxed|fast) 得 job_id\n2. 轮询 midjourney_query_job_status:cookie、job_id 直至完成\n3. midjourney_get_image_urls:job_id 取四张图 URL"
     }
   ],
   "version": "1.0"

+ 122 - 0
data/registry.json

@@ -653,6 +653,128 @@
       "group_ids": [
         "flux_bfl_lifecycle"
       ]
+    },
+    {
+      "tool_id": "midjourney_submit_job",
+      "name": "Midjourney-提交生图任务",
+      "tool_slug_ids": [],
+      "category": "cv",
+      "description": "提交 Midjourney 生图任务(转发至 MIDJOURNEY_API_BASE/submit_job)。需配置 tools/local/midjourney/.env。mode 为 relaxed 或 fast。",
+      "input_schema": {
+        "type": "object",
+        "properties": {
+          "cookie": {
+            "type": "string",
+            "description": "Midjourney 会话 cookie"
+          },
+          "prompt": {
+            "type": "string",
+            "description": "提示词"
+          },
+          "user_id": {
+            "type": "string",
+            "description": "用户 ID"
+          },
+          "mode": {
+            "type": "string",
+            "enum": [
+              "relaxed",
+              "fast"
+            ],
+            "description": "队列模式:relaxed 或 fast"
+          }
+        },
+        "required": [
+          "cookie",
+          "prompt",
+          "user_id",
+          "mode"
+        ]
+      },
+      "output_schema": {
+        "type": "object",
+        "description": "上游返回 JSON,通常含 job_id 或 id",
+        "properties": {}
+      },
+      "stream_support": false,
+      "status": "active",
+      "backend_runtime": "local",
+      "group_ids": [
+        "midjourney_lifecycle"
+      ]
+    },
+    {
+      "tool_id": "midjourney_query_job_status",
+      "name": "Midjourney-查询任务状态",
+      "tool_slug_ids": [],
+      "category": "cv",
+      "description": "查询指定任务状态(转发 MIDJOURNEY_API_BASE/query_job_status)。",
+      "input_schema": {
+        "type": "object",
+        "properties": {
+          "cookie": {
+            "type": "string",
+            "description": "Midjourney 会话 cookie"
+          },
+          "job_id": {
+            "type": "string",
+            "description": "submit_job 返回的任务 ID"
+          }
+        },
+        "required": [
+          "cookie",
+          "job_id"
+        ]
+      },
+      "output_schema": {
+        "type": "object",
+        "description": "上游状态 JSON,如 status / job_status 等",
+        "properties": {}
+      },
+      "stream_support": false,
+      "status": "active",
+      "backend_runtime": "local",
+      "group_ids": [
+        "midjourney_lifecycle"
+      ]
+    },
+    {
+      "tool_id": "midjourney_get_image_urls",
+      "name": "Midjourney-获取结果图链接",
+      "tool_slug_ids": [],
+      "category": "cv",
+      "description": "根据 job_id 获取 4 张图 URL(转发 MIDJOURNEY_API_BASE/get_image_urls)。",
+      "input_schema": {
+        "type": "object",
+        "properties": {
+          "job_id": {
+            "type": "string",
+            "description": "任务 ID"
+          }
+        },
+        "required": [
+          "job_id"
+        ]
+      },
+      "output_schema": {
+        "type": "object",
+        "description": "上游返回 JSON 或 URL 数组,常见字段 image_urls / urls",
+        "properties": {
+          "image_urls": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "description": "四张图片链接(字段名以实际服务为准)"
+          }
+        }
+      },
+      "stream_support": false,
+      "status": "active",
+      "backend_runtime": "local",
+      "group_ids": [
+        "midjourney_lifecycle"
+      ]
     }
   ],
   "version": "2.0"

+ 42 - 0
data/sources.json

@@ -139,6 +139,48 @@
         "http_method": "POST",
         "internal_port": 0
       }
+    ],
+    "midjourney_submit_job": [
+      {
+        "type": "local",
+        "host_dir": "tools/local/midjourney",
+        "container_id": "",
+        "image": "",
+        "hub_url": "",
+        "hub_tool_path": "",
+        "hub_api_key": "",
+        "endpoint_path": "/submit_job",
+        "http_method": "POST",
+        "internal_port": 0
+      }
+    ],
+    "midjourney_query_job_status": [
+      {
+        "type": "local",
+        "host_dir": "tools/local/midjourney",
+        "container_id": "",
+        "image": "",
+        "hub_url": "",
+        "hub_tool_path": "",
+        "hub_api_key": "",
+        "endpoint_path": "/query_job_status",
+        "http_method": "POST",
+        "internal_port": 0
+      }
+    ],
+    "midjourney_get_image_urls": [
+      {
+        "type": "local",
+        "host_dir": "tools/local/midjourney",
+        "container_id": "",
+        "image": "",
+        "hub_url": "",
+        "hub_tool_path": "",
+        "hub_api_key": "",
+        "endpoint_path": "/get_image_urls",
+        "http_method": "POST",
+        "internal_port": 0
+      }
     ]
   }
 }

+ 1 - 0
pyproject.toml

@@ -44,4 +44,5 @@ members = [
     "tools/local/ji_meng",
     "tools/local/nano_banana",
     "tools/local/flux",
+    "tools/local/midjourney",
 ]

+ 231 - 0
tests/test_midjourney.py

@@ -0,0 +1,231 @@
+"""测试 Midjourney 代理工具 — Router POST /run_tool
+
+本地服务将 JSON 原样转发至 MIDJOURNEY_API_BASE 上你已实现的三个接口:
+  POST /submit_job       cookie, prompt, user_id, mode(relaxed|fast)
+  POST /query_job_status cookie, job_id
+  POST /get_image_urls   job_id → 四张图链接
+
+用法:
+    1. tools/local/midjourney/.env:MIDJOURNEY_API_BASE
+    2. uv run python -m tool_agent
+    3. uv run python tests/test_midjourney.py
+
+端到端(可选):设置 MIDJOURNEY_TEST_COOKIE、MIDJOURNEY_TEST_USER_ID 后脚本会
+    submit → 轮询 query → get_image_urls;否则仅校验工具已注册并退出 0。
+
+环境变量:
+    TOOL_AGENT_ROUTER_URL
+    MIDJOURNEY_SUBMIT_TOOL_ID      默认 midjourney_submit_job
+    MIDJOURNEY_QUERY_TOOL_ID       默认 midjourney_query_job_status
+    MIDJOURNEY_GET_URLS_TOOL_ID    默认 midjourney_get_image_urls
+    MIDJOURNEY_TEST_COOKIE / MIDJOURNEY_TEST_USER_ID / MIDJOURNEY_TEST_PROMPT / MIDJOURNEY_TEST_MODE
+    MIDJOURNEY_POLL_INTERVAL_S / MIDJOURNEY_POLL_MAX_WAIT_S
+"""
+
+from __future__ import annotations
+
+import io
+import os
+import sys
+import time
+from typing import Any
+
+if sys.platform == "win32":
+    _out = sys.stdout
+    if isinstance(_out, io.TextIOWrapper):
+        _out.reconfigure(encoding="utf-8")
+
+import httpx
+
+ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
+T_SUBMIT = os.environ.get("MIDJOURNEY_SUBMIT_TOOL_ID", "midjourney_submit_job")
+T_QUERY = os.environ.get("MIDJOURNEY_QUERY_TOOL_ID", "midjourney_query_job_status")
+T_URLS = os.environ.get("MIDJOURNEY_GET_URLS_TOOL_ID", "midjourney_get_image_urls")
+TEST_COOKIE = os.environ.get("MIDJOURNEY_TEST_COOKIE", "").strip()
+TEST_USER_ID = os.environ.get("MIDJOURNEY_TEST_USER_ID", "").strip()
+TEST_PROMPT = os.environ.get("MIDJOURNEY_TEST_PROMPT", "a red apple on white background --v 6")
+TEST_MODE = os.environ.get("MIDJOURNEY_TEST_MODE", "fast").strip().lower()
+POLL_INTERVAL_S = float(os.environ.get("MIDJOURNEY_POLL_INTERVAL_S", "3"))
+POLL_MAX_WAIT_S = float(os.environ.get("MIDJOURNEY_POLL_MAX_WAIT_S", "600"))
+
+
+def run_tool(tool_id: str, params: dict[str, Any], timeout: float = 120.0) -> Any:
+    resp = httpx.post(
+        f"{ROUTER_URL}/run_tool",
+        json={"tool_id": tool_id, "params": params},
+        timeout=timeout,
+    )
+    resp.raise_for_status()
+    body = resp.json()
+    if body.get("status") != "success":
+        raise RuntimeError(body.get("error") or str(body))
+    result = body.get("result")
+    if isinstance(result, dict) and result.get("status") == "error":
+        raise RuntimeError(result.get("error", str(result)))
+    return result
+
+
+def _extract_job_id(data: dict[str, Any]) -> str | None:
+    if not isinstance(data, dict):
+        return None
+    for key in ("job_id", "jobId", "id", "task_id", "taskId"):
+        v = data.get(key)
+        if v is not None and str(v).strip():
+            return str(v).strip()
+    inner = data.get("data")
+    if isinstance(inner, dict):
+        return _extract_job_id(inner)
+    return None
+
+
+def _status_terminal_ok(data: dict[str, Any]) -> bool:
+    if not isinstance(data, dict):
+        return False
+    s = str(
+        data.get("status")
+        or data.get("job_status")
+        or data.get("jobStatus")
+        or data.get("state")
+        or ""
+    ).lower()
+    if not s and isinstance(data.get("data"), dict):
+        return _status_terminal_ok(data["data"])
+    return any(k in s for k in ("complete", "success", "done", "finished", "succeed", "ready"))
+
+
+def _status_terminal_fail(data: dict[str, Any]) -> bool:
+    if not isinstance(data, dict):
+        return False
+    s = str(data.get("status") or data.get("job_status") or data.get("state") or "").lower()
+    return any(k in s for k in ("fail", "error", "cancel", "canceled", "cancelled"))
+
+
+def _extract_url_list(payload: Any) -> list[str]:
+    if isinstance(payload, list):
+        return [str(x) for x in payload if isinstance(x, str) and x.startswith("http")]
+    if not isinstance(payload, dict):
+        return []
+    for key in ("image_urls", "urls", "images", "data"):
+        v = payload.get(key)
+        if isinstance(v, list):
+            out = [str(x) for x in v if isinstance(x, str) and x.startswith("http")]
+            if out:
+                return out
+        if isinstance(v, dict):
+            nested = _extract_url_list(v)
+            if nested:
+                return nested
+    return _extract_url_list(payload.get("data"))
+
+
+def main() -> None:
+    print("=" * 50)
+    print("测试 Midjourney(submit / query / get_image_urls)")
+    print("=" * 50)
+    print(f"ROUTER_URL: {ROUTER_URL}")
+
+    try:
+        r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
+        print(f"Router 状态: {r.json()}")
+    except httpx.ConnectError:
+        print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
+        sys.exit(1)
+
+    print("\n--- 校验工具已注册 ---")
+    tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
+    tr.raise_for_status()
+    tools = tr.json().get("tools", [])
+    ids = {t["tool_id"] for t in tools}
+    for tid in (T_SUBMIT, T_QUERY, T_URLS):
+        if tid not in ids:
+            print(f"错误: {tid!r} 不在 GET /tools 中。示例: {sorted(ids)[:25]}...")
+            sys.exit(1)
+        meta = next(t for t in tools if t["tool_id"] == tid)
+        print(f"  {tid}: {meta.get('name', '')} (state={meta.get('state')})")
+
+    if not TEST_COOKIE or not TEST_USER_ID:
+        print(
+            "\n未设置 MIDJOURNEY_TEST_COOKIE 与 MIDJOURNEY_TEST_USER_ID,跳过端到端;"
+            "工具注册检查已通过,退出 0。"
+        )
+        return
+
+    if TEST_MODE not in ("relaxed", "fast"):
+        print(f"错误: MIDJOURNEY_TEST_MODE 须为 relaxed 或 fast,当前: {TEST_MODE!r}")
+        sys.exit(1)
+
+    print("\n--- midjourney_submit_job ---")
+    try:
+        sub = run_tool(
+            T_SUBMIT,
+            {
+                "cookie": TEST_COOKIE,
+                "prompt": TEST_PROMPT,
+                "user_id": TEST_USER_ID,
+                "mode": TEST_MODE,
+            },
+            timeout=180.0,
+        )
+    except (RuntimeError, httpx.HTTPError) as e:
+        print(f"错误: {e}")
+        sys.exit(1)
+
+    if not isinstance(sub, dict):
+        print(f"错误: submit 返回非 object: {type(sub)}")
+        sys.exit(1)
+
+    job_id = _extract_job_id(sub)
+    if not job_id:
+        print(f"错误: 无法从 submit 响应解析 job_id: {sub}")
+        sys.exit(1)
+    print(f"job_id: {job_id}")
+
+    print("\n--- midjourney_query_job_status 轮询 ---")
+    deadline = time.monotonic() + POLL_MAX_WAIT_S
+    last: dict[str, Any] = {}
+
+    while time.monotonic() < deadline:
+        time.sleep(POLL_INTERVAL_S)
+        try:
+            q = run_tool(
+                T_QUERY,
+                {"cookie": TEST_COOKIE, "job_id": job_id},
+                timeout=120.0,
+            )
+        except (RuntimeError, httpx.HTTPError) as e:
+            print(f"轮询错误: {e}")
+            sys.exit(1)
+
+        last = q if isinstance(q, dict) else {}
+        st = last.get("status") or last.get("job_status") or last.get("state")
+        print(f"  status: {st}")
+
+        if _status_terminal_fail(last):
+            print(f"任务失败: {last}")
+            sys.exit(1)
+        if _status_terminal_ok(last):
+            break
+    else:
+        print(f"等待超时 ({POLL_MAX_WAIT_S}s),最后响应: {last}")
+        sys.exit(1)
+
+    print("\n--- midjourney_get_image_urls ---")
+    try:
+        urls_payload = run_tool(T_URLS, {"job_id": job_id}, timeout=120.0)
+    except (RuntimeError, httpx.HTTPError) as e:
+        print(f"错误: {e}")
+        sys.exit(1)
+
+    urls = _extract_url_list(urls_payload)
+    if len(urls) < 4:
+        print(f"警告: 期望至少 4 个 http 链接,实际 {len(urls)};原始: {str(urls_payload)[:500]}")
+        if len(urls) == 0:
+            sys.exit(1)
+
+    for i, u in enumerate(urls[:4], 1):
+        print(f"  [{i}] {u[:96]}...")
+    print("\n测试通过!")
+
+
+if __name__ == "__main__":
+    main()

+ 2 - 0
tools/local/midjourney/.env.example

@@ -0,0 +1,2 @@
+# 你已部署的 Midjourney 服务根地址(将请求 POST 到 {BASE}/submit_job 等)
+MIDJOURNEY_API_BASE=https://your-mj-api.example.com

+ 3 - 0
tools/local/midjourney/.gitignore

@@ -0,0 +1,3 @@
+.env
+.venv/
+__pycache__/

+ 99 - 0
tools/local/midjourney/main.py

@@ -0,0 +1,99 @@
+"""Midjourney 本地代理 — 三个 POST 与上游 JSON 对齐。
+
+环境变量:
+  MIDJOURNEY_API_BASE  必填,例如 https://your-host(后接 /submit_job 等)
+
+接口(与 Router 注册一致):
+  GET  /health
+  POST /submit_job        cookie, prompt, user_id, mode ∈ relaxed|fast
+  POST /query_job_status  cookie, job_id
+  POST /get_image_urls    job_id
+"""
+
+from __future__ import annotations
+
+import argparse
+from typing import Any, Literal
+
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, Field
+
+from midjourney_client import forward_post
+
+app = FastAPI(title="Midjourney API Proxy")
+
+
+class SubmitJobRequest(BaseModel):
+    cookie: str = Field(..., description="Midjourney 会话 cookie")
+    prompt: str = Field(..., description="提示词")
+    user_id: str = Field(..., description="用户 ID")
+    mode: Literal["relaxed", "fast"] = Field(..., description="relaxed 或 fast")
+
+
+class QueryJobStatusRequest(BaseModel):
+    cookie: str = Field(..., description="Midjourney 会话 cookie")
+    job_id: str = Field(..., description="submit_job 返回的任务 ID")
+
+
+class GetImageUrlsRequest(BaseModel):
+    job_id: str = Field(..., description="任务 ID")
+
+
+@app.get("/health")
+def health() -> dict[str, str]:
+    return {"status": "ok"}
+
+
+@app.post("/submit_job")
+def submit_job(req: SubmitJobRequest) -> Any:
+    try:
+        return forward_post(
+            "/submit_job",
+            {
+                "cookie": req.cookie,
+                "prompt": req.prompt,
+                "user_id": req.user_id,
+                "mode": req.mode,
+            },
+        )
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e)) from e
+    except RuntimeError as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+
+
+@app.post("/query_job_status")
+def query_job_status(req: QueryJobStatusRequest) -> Any:
+    try:
+        return forward_post(
+            "/query_job_status",
+            {"cookie": req.cookie, "job_id": req.job_id},
+        )
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e)) from e
+    except RuntimeError as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+
+
+@app.post("/get_image_urls")
+def get_image_urls(req: GetImageUrlsRequest) -> Any:
+    try:
+        return forward_post("/get_image_urls", {"job_id": req.job_id})
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e)) from e
+    except RuntimeError as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=str(e)) from e
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8001)
+    args = parser.parse_args()
+    uvicorn.run(app, host="0.0.0.0", port=args.port)

+ 46 - 0
tools/local/midjourney/midjourney_client.py

@@ -0,0 +1,46 @@
+"""将请求转发到自建 Midjourney HTTP 服务(与 tools/local 约定一致)。"""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import httpx
+from dotenv import load_dotenv
+
+_ = load_dotenv()
+
+
+def _base_url() -> str:
+    base = os.environ.get("MIDJOURNEY_API_BASE", "").strip().rstrip("/")
+    if not base:
+        raise ValueError("缺少环境变量 MIDJOURNEY_API_BASE(上游根 URL,不含尾路径)")
+    return base
+
+
+def forward_post(path: str, json_body: dict[str, Any]) -> Any:
+    """POST {MIDJOURNEY_API_BASE}{path},Content-Type: application/json。"""
+    url = f"{_base_url()}{path if path.startswith('/') else '/' + path}"
+    with httpx.Client(timeout=300.0) as client:
+        r = client.post(
+            url,
+            json=json_body,
+            headers={"accept": "application/json", "Content-Type": "application/json"},
+        )
+        ct = (r.headers.get("content-type") or "").lower()
+        if "application/json" not in ct:
+            r.raise_for_status()
+            raise RuntimeError(f"非 JSON 响应 ({r.status_code}): {r.text[:1500]}")
+        try:
+            data = r.json()
+        except Exception:
+            raise RuntimeError(f"无效 JSON ({r.status_code}): {r.text[:1500]}") from None
+
+    if r.status_code >= 400:
+        if isinstance(data, dict):
+            msg = data.get("detail", data.get("message", data.get("error", str(data))))
+        else:
+            msg = str(data)
+        raise RuntimeError(f"上游 HTTP {r.status_code}: {msg}")
+
+    return data

+ 12 - 0
tools/local/midjourney/pyproject.toml

@@ -0,0 +1,12 @@
+[project]
+name = "midjourney-proxy"
+version = "0.1.0"
+description = "Midjourney 上游 HTTP 代理:submit_job / query_job_status / get_image_urls"
+requires-python = ">=3.11"
+dependencies = [
+    "fastapi>=0.115.0",
+    "uvicorn>=0.30.0",
+    "pydantic>=2.0.0",
+    "python-dotenv>=1.0.0",
+    "httpx>=0.27.0",
+]

+ 22 - 0
uv.lock

@@ -9,6 +9,7 @@ members = [
     "ji-meng",
     "launch-comfy-env",
     "liblibai-controlnet",
+    "midjourney-proxy",
     "nano-banana",
     "runcomfy-stop-env",
     "task-0cd69d84",
@@ -564,6 +565,27 @@ wheels = [
     { url = "https://mirrors.ustc.edu.cn/pypi/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" },
 ]
 
+[[package]]
+name = "midjourney-proxy"
+version = "0.1.0"
+source = { virtual = "tools/local/midjourney" }
+dependencies = [
+    { name = "fastapi" },
+    { name = "httpx" },
+    { name = "pydantic" },
+    { name = "python-dotenv" },
+    { name = "uvicorn" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "fastapi", specifier = ">=0.115.0" },
+    { name = "httpx", specifier = ">=0.27.0" },
+    { name = "pydantic", specifier = ">=2.0.0" },
+    { name = "python-dotenv", specifier = ">=1.0.0" },
+    { name = "uvicorn", specifier = ">=0.30.0" },
+]
+
 [[package]]
 name = "nano-banana"
 version = "0.1.0"