Explorar o código

修复stitch测试

kevin.yang hai 4 días
pai
achega
95f67713da
Modificáronse 3 ficheiros con 68 adicións e 39 borrados
  1. 11 4
      src/tool_agent/router/dispatcher.py
  2. 1 1
      src/tool_agent/router/server.py
  3. 56 34
      tests/test_router_api.py

+ 11 - 4
src/tool_agent/router/dispatcher.py

@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import logging
 from typing import Any, TYPE_CHECKING
 
@@ -20,11 +21,17 @@ class Dispatcher:
         self._status_manager = status_manager
 
     async def dispatch(self, tool_id: str, params: dict[str, Any], stream: bool = False) -> dict[str, Any]:
-        """分发调用请求到工具的活跃端点"""
-        # 1. 获取端点信息
-        endpoint = self._status_manager.get_active_endpoint(tool_id)
+        """分发调用请求到工具的活跃端点;无可用端点时先尝试启动(本地 uv 进程等)。"""
+        sm = self._status_manager
+        endpoint = sm.get_active_endpoint(tool_id)
         if not endpoint:
-            return {"status": "error", "error": f"Tool '{tool_id}' is not running or has no active endpoint"}
+            await asyncio.to_thread(sm.start_tool, tool_id)
+            endpoint = sm.get_active_endpoint(tool_id)
+        if not endpoint:
+            route = sm.get_status(tool_id)
+            err = (route.last_error or "").strip() if route else ""
+            msg = err or f"Tool '{tool_id}' is not running or has no active endpoint"
+            return {"status": "error", "error": msg}
 
         # 2. 根据端点类型调用
         try:

+ 1 - 1
src/tool_agent/router/server.py

@@ -118,7 +118,7 @@ def create_app(router: Router, session_manager: SessionManager = None) -> FastAP
 
     @app.post("/run_tool")
     async def run_tool(request: RunToolRequest):
-        """调用已注册的工具"""
+        """调用已注册的工具(Dispatcher 会在需要时自动 start_tool)。"""
         try:
             result = await router.dispatcher.dispatch(request.tool_id, request.params)
             return RunToolResponse(status="success", result=result)

+ 56 - 34
tests/test_router_api.py

@@ -5,7 +5,7 @@
     uv run python tests/test_router_api.py --search                      # 搜索工具列表
     uv run python tests/test_router_api.py --search image                # 关键词搜索
     uv run python tests/test_router_api.py --status                      # 工具运行状态
-    uv run python tests/test_router_api.py --select image_stitcher       # 调用指定工具
+    uv run python tests/test_router_api.py --select image_stitcher       # POST /run_tool 调用工具
     uv run python tests/test_router_api.py --stitch                      # 测试图片拼接
     uv run python tests/test_router_api.py --create                      # 默认任务
     uv run python tests/test_router_api.py --create image_stitcher       # 指定任务文件
@@ -20,6 +20,7 @@ import json
 import sys
 import time
 from pathlib import Path
+from typing import Any
 
 import httpx
 
@@ -99,21 +100,40 @@ def test_tools_status():
     print("  [PASS]")
 
 
-def test_select_tool(tool_id: str):
-    print(f"=== Select Tool (tool_id={tool_id!r}) ===")
-    resp = httpx.post(f"{BASE_URL}/select_tool", json={
-        "tool_id": tool_id,
-        "params": {}
-    }, timeout=30)
+def _run_tool(
+    tool_id: str, params: dict[str, Any], timeout: float = 120.0
+) -> tuple[bool, str | None, Any]:
+    """POST /run_tool。成功返回 (True, None, result);失败 (False, message, None)。"""
+    resp = httpx.post(
+        f"{BASE_URL}/run_tool",
+        json={"tool_id": tool_id, "params": params},
+        timeout=timeout,
+    )
     print(f"  Status : {resp.status_code}")
-    data = resp.json()
+    if resp.status_code != 200:
+        return False, f"HTTP {resp.status_code}: {resp.text[:300]}", None
+    try:
+        data = resp.json()
+    except Exception as e:
+        return False, f"Invalid JSON: {e}", None
+    if data.get("status") != "success":
+        return False, data.get("error") or str(data), None
+    result = data.get("result")
+    if isinstance(result, dict) and result.get("status") == "error":
+        return False, str(result.get("error", result)), None
+    return True, None, result
+
+
+def test_select_tool(tool_id: str):
+    print(f"=== Run Tool (tool_id={tool_id!r}) ===")
+    ok, err, result = _run_tool(tool_id, {}, timeout=30)
     print(f"  Result :")
-    print(f"    status: {data.get('status')}")
-    if data.get("error"):
-        print(f"    error : {data['error']}")
-    else:
-        result_str = json.dumps(data.get("result"), ensure_ascii=False, indent=6)
-        print(f"    result: {result_str[:500]}")
+    if not ok:
+        print(f"    error : {err}")
+        print("  [FAIL]")
+        return
+    result_str = json.dumps(result, ensure_ascii=False, indent=6)
+    print(f"    body: {result_str[:500]}")
     print("  [PASS]")
 
 
@@ -139,32 +159,34 @@ def test_stitch_images():
 
     print(f"  Calling image_stitcher (grid, 2 columns)...")
     try:
-        resp = httpx.post(f"{BASE_URL}/select_tool", json={
-            "tool_id": "image_stitcher",
-            "params": {
+        ok, err, result = _run_tool(
+            "image_stitcher",
+            {
                 "images": images_b64,
                 "direction": "grid",
                 "columns": 2,
                 "spacing": 10,
                 "background_color": "#FFFFFF",
-            }
-        }, timeout=60)
-        print(f"  Status : {resp.status_code}")
-        data = resp.json()
-        if data["status"] == "success":
-            result = data["result"]
-            print(f"  Result :")
-            print(f"    width : {result['width']}")
-            print(f"    height: {result['height']}")
-            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
-            output_path = OUTPUT_DIR / "stitched_result.png"
-            with open(output_path, "wb") as f:
-                f.write(base64.b64decode(result["image"]))
-            print(f"    saved : {output_path}")
-            print("  [PASS]")
-        else:
-            print(f"  ERROR : {data.get('error', 'unknown')}")
+            },
+            timeout=120.0,
+        )
+        if not ok:
+            print(f"  ERROR : {err}")
+            print("  [FAIL]")
+            return
+        if not isinstance(result, dict) or "image" not in result:
+            print(f"  ERROR : 缺少 image 字段: {result!r}")
             print("  [FAIL]")
+            return
+        print(f"  Result :")
+        print(f"    width : {result.get('width')}")
+        print(f"    height: {result.get('height')}")
+        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+        output_path = OUTPUT_DIR / "stitched_result.png"
+        with open(output_path, "wb") as f:
+            f.write(base64.b64decode(result["image"]))
+        print(f"    saved : {output_path}")
+        print("  [PASS]")
     except httpx.TimeoutException:
         print("  ERROR : Request timeout")
         print("  [FAIL]")