| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
- """测试 nano_banana — Router 调用 Gemini 图模(HTTP generateContent)
- 前提:
- - data/registry.json + data/sources.json 已注册 tool_id=nano_banana
- - tools/local/nano_banana 已提供 POST /generate,且 .env 中配置 GEMINI_API_KEY
- 用法:
- 1. uv run python -m tool_agent
- 2. uv run python tests/test_nano_banana.py
- 模型切换(任选其一):
- - 不传 NANO_BANANA_MODEL:请求体不含 model,由工具侧默认(如 gemini-2.5-flash-image /
- 环境变量 GEMINI_IMAGE_MODEL)
- - 显式切换预览图模:
- NANO_BANANA_MODEL=gemini-3.1-flash-image-preview uv run python tests/test_nano_banana.py
- 环境变量:
- TOOL_AGENT_ROUTER_URL 默认 http://127.0.0.1:8001
- NANO_BANANA_TOOL_ID 默认 nano_banana
- NANO_BANANA_TEST_PROMPT 覆盖默认短提示词
- NANO_BANANA_MODEL 非空时作为 params["model"] 传给 /run_tool
- """
- import io
- import os
- import sys
- from typing import Any
if sys.platform == "win32":
    # The Windows console may use a non-UTF-8 code page, in which case printing
    # the Chinese status text below can raise UnicodeEncodeError. Switch stdout
    # to UTF-8, but only when it is a real TextIOWrapper: a replaced stream
    # (e.g. a capture object) may not provide reconfigure().
    if isinstance(_out := sys.stdout, io.TextIOWrapper):
        _out.reconfigure(encoding="utf-8")
- import httpx
# Runtime configuration, read once from the environment at import time.
# Blank values (unset, "" or whitespace) fall back to the documented defaults
# instead of silently producing an empty URL / tool id / prompt.

# Router base URL. Trailing slashes are stripped so f"{ROUTER_URL}/run_tool"
# never contains a double slash.
ROUTER_URL = (
    os.environ.get("TOOL_AGENT_ROUTER_URL", "").strip() or "http://127.0.0.1:8001"
).rstrip("/")
# Registered tool id to look up in GET /tools and to invoke via POST /run_tool.
TOOL_ID = os.environ.get("NANO_BANANA_TOOL_ID", "").strip() or "nano_banana"
# Optional model override. When empty, no "model" key is added to the params,
# so the tool side chooses its own default.
NANO_BANANA_MODEL = os.environ.get("NANO_BANANA_MODEL", "").strip()
# Prompt sent to the image model.
TEST_PROMPT = (
    os.environ.get("NANO_BANANA_TEST_PROMPT", "").strip()
    or "A minimal flat icon of a yellow banana on white background, no text"
)
def run_tool(params: dict[str, Any], timeout: float = 180.0) -> dict[str, Any]:
    """Invoke TOOL_ID through the Router's POST /run_tool and return its result.

    Args:
        params: Tool parameters, sent as the ``params`` field of the request.
        timeout: Request timeout in seconds passed to httpx.

    Returns:
        The ``result`` dict from the Router response, or ``{}`` when ``result``
        is missing or not a dict.

    Raises:
        httpx.HTTPError: Transport failure or non-2xx HTTP status.
        RuntimeError: The body is not a JSON object, the Router reports
            ``status != "success"``, or the tool result has
            ``status == "error"``.
    """
    resp = httpx.post(
        f"{ROUTER_URL}/run_tool",
        json={"tool_id": TOOL_ID, "params": params},
        timeout=timeout,
    )
    resp.raise_for_status()
    try:
        body = resp.json()
    except ValueError as e:
        # json.JSONDecodeError subclasses ValueError; surface it as a
        # RuntimeError so callers only need to handle one failure type.
        raise RuntimeError(f"Router 返回非 JSON 响应: {resp.text[:200]}") from e
    if not isinstance(body, dict):
        raise RuntimeError(f"Router 返回的 JSON 不是对象: {body!r}")
    if body.get("status") != "success":
        raise RuntimeError(body.get("error") or str(body))
    result = body.get("result")
    if isinstance(result, dict) and result.get("status") == "error":
        # `or` rather than a .get() default, so an explicit null/empty "error"
        # still produces a readable message instead of RuntimeError(None).
        raise RuntimeError(result.get("error") or str(result))
    return result if isinstance(result, dict) else {}
- def _has_image_payload(data: dict[str, Any]) -> bool:
- if not data:
- return False
- if data.get("images"):
- return True
- if data.get("image") and isinstance(data["image"], str) and len(data["image"]) > 100:
- return True
- if data.get("image_base64"):
- return True
- cands = data.get("candidates")
- if isinstance(cands, list) and cands:
- parts = cands[0].get("content", {}).get("parts", [])
- for p in parts:
- if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
- return True
- return False
def _check_router() -> None:
    """Print the Router's GET /health response; exit(1) if it is unreachable."""
    try:
        r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
    except httpx.TransportError:
        # TransportError covers ConnectError *and* ConnectTimeout/ReadTimeout;
        # catching only ConnectError let timeouts escape as a raw traceback.
        print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
        sys.exit(1)
    try:
        health: Any = r.json()
    except ValueError:
        # Non-JSON /health body: show a prefix of the raw text instead of crashing.
        health = r.text[:200]
    print(f"Router 状态: {health}")


def _fetch_tool_meta() -> dict[str, Any]:
    """Return the GET /tools entry whose tool_id equals TOOL_ID, or exit(1)."""
    try:
        tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
        tr.raise_for_status()
        payload = tr.json()
    except (httpx.HTTPError, ValueError) as e:
        print(f"错误: GET /tools 失败: {e}")
        sys.exit(1)
    tools = (payload.get("tools") if isinstance(payload, dict) else None) or []
    # Skip malformed entries instead of failing with KeyError/TypeError.
    entries = [t for t in tools if isinstance(t, dict) and "tool_id" in t]
    meta = next((t for t in entries if t["tool_id"] == TOOL_ID), None)
    if meta is None:
        ids = sorted({str(t["tool_id"]) for t in entries})
        # Only append "..." when the listing was actually truncated.
        more = "..." if len(ids) > 15 else ""
        print(f"错误: {TOOL_ID!r} 不在 GET /tools 中。当前示例: {ids[:15]}{more}")
        sys.exit(1)
    return meta


def main() -> None:
    """Smoke-test the nano_banana tool end to end through the Router.

    Steps:
      1. GET /health on the Router (exit 1 if unreachable).
      2. GET /tools and confirm TOOL_ID is registered; report whether its
         input_schema declares a ``model`` property.
      3. POST /run_tool with TEST_PROMPT (plus ``model`` when
         NANO_BANANA_MODEL is set) and check the result for image data.

    Calls sys.exit(1) on each handled failure; returns normally on success.
    """
    print("=" * 50)
    print("测试 nano_banana(Gemini 图模,可切换 model)")
    print("=" * 50)
    print(f"ROUTER_URL: {ROUTER_URL}")
    print(f"tool_id: {TOOL_ID}")
    if NANO_BANANA_MODEL:
        print(f"model: {NANO_BANANA_MODEL}(经 params 传入)")
    else:
        print("model: (未传,使用工具默认 / GEMINI_IMAGE_MODEL)")

    _check_router()

    print("\n--- 校验工具已注册 ---")
    meta = _fetch_tool_meta()
    print(f" {TOOL_ID}: {meta.get('name', '')} (state={meta.get('state')})")
    props = (meta.get("input_schema") or {}).get("properties") or {}
    if "model" in props:
        print(" input_schema 已声明 model(注册与实现应对齐)")
    else:
        print(" 提示: input_schema 尚无 model 字段,注册表宜补充以便编排知晓可切换模型")

    params: dict[str, Any] = {"prompt": TEST_PROMPT}
    if NANO_BANANA_MODEL:
        params["model"] = NANO_BANANA_MODEL

    print("\n--- 调用生图 ---")
    print(f"prompt: {TEST_PROMPT[:80]}{'...' if len(TEST_PROMPT) > 80 else ''}")
    try:
        data = run_tool(params, timeout=180.0)
    except (RuntimeError, ValueError, httpx.HTTPError) as e:
        # ValueError covers a non-JSON /run_tool body (json.JSONDecodeError).
        print(f"错误: {e}")
        sys.exit(1)

    print(f"\n下游返回 keys: {list(data.keys())[:20]}")
    if rm := data.get("model"):
        print(f"下游报告 model: {rm}")
        if NANO_BANANA_MODEL and rm != NANO_BANANA_MODEL:
            print(
                f"警告: 请求 model={NANO_BANANA_MODEL!r} 与返回 model={rm!r} 不一致(若工具会规范化 ID 可忽略)"
            )

    if _has_image_payload(data):
        print("\n检测到图片相关字段,测试通过!")
        return
    print("\n未识别到常见图片字段(images / image / candidates[].inlineData 等)。")
    print(f"完整结果(截断): {str(data)[:800]}")
    sys.exit(1)
# Script entry point: run the smoke test only when executed directly
# (e.g. `uv run python tests/test_nano_banana.py`), not on import.
if __name__ == "__main__":
    main()
|