test_nano_banana.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """测试 nano_banana — Router 调用 Gemini 图模(HTTP generateContent)
  2. 前提:
  3. - data/registry.json + data/sources.json 已注册 tool_id=nano_banana
  4. - tools/local/nano_banana 已提供 POST /generate,且 .env 中配置 GEMINI_API_KEY
  5. 用法:
  6. 1. uv run python -m tool_agent
  7. 2. uv run python tests/test_nano_banana.py
  8. 模型切换(任选其一):
  9. - 不传 NANO_BANANA_MODEL:请求体不含 model,由工具侧默认(如 gemini-2.5-flash-image /
  10. 环境变量 GEMINI_IMAGE_MODEL)
  11. - 显式切换预览图模:
  12. NANO_BANANA_MODEL=gemini-3.1-flash-image-preview uv run python tests/test_nano_banana.py
  13. 环境变量:
  14. TOOL_AGENT_ROUTER_URL 默认 http://127.0.0.1:8001
  15. NANO_BANANA_TOOL_ID 默认 nano_banana
  16. NANO_BANANA_TEST_PROMPT 覆盖默认短提示词
  17. NANO_BANANA_MODEL 非空时作为 params["model"] 传给 /run_tool
  18. """
  19. import io
  20. import os
  21. import sys
  22. from typing import Any
  23. if sys.platform == "win32":
  24. _out = sys.stdout
  25. if isinstance(_out, io.TextIOWrapper):
  26. _out.reconfigure(encoding="utf-8")
  27. import httpx
  28. ROUTER_URL = os.environ.get("TOOL_AGENT_ROUTER_URL", "http://127.0.0.1:8001")
  29. TOOL_ID = os.environ.get("NANO_BANANA_TOOL_ID", "nano_banana")
  30. NANO_BANANA_MODEL = os.environ.get("NANO_BANANA_MODEL", "").strip()
  31. TEST_PROMPT = os.environ.get(
  32. "NANO_BANANA_TEST_PROMPT",
  33. "A minimal flat icon of a yellow banana on white background, no text",
  34. )
  35. def run_tool(params: dict[str, Any], timeout: float = 180.0) -> dict[str, Any]:
  36. resp = httpx.post(
  37. f"{ROUTER_URL}/run_tool",
  38. json={"tool_id": TOOL_ID, "params": params},
  39. timeout=timeout,
  40. )
  41. resp.raise_for_status()
  42. body = resp.json()
  43. if body.get("status") != "success":
  44. raise RuntimeError(body.get("error") or str(body))
  45. result = body.get("result")
  46. if isinstance(result, dict) and result.get("status") == "error":
  47. raise RuntimeError(result.get("error", str(result)))
  48. return result if isinstance(result, dict) else {}
  49. def _has_image_payload(data: dict[str, Any]) -> bool:
  50. if not data:
  51. return False
  52. if data.get("images"):
  53. return True
  54. if data.get("image") and isinstance(data["image"], str) and len(data["image"]) > 100:
  55. return True
  56. if data.get("image_base64"):
  57. return True
  58. cands = data.get("candidates")
  59. if isinstance(cands, list) and cands:
  60. parts = cands[0].get("content", {}).get("parts", [])
  61. for p in parts:
  62. if isinstance(p, dict) and (p.get("inlineData") or p.get("inline_data")):
  63. return True
  64. return False
  65. def main():
  66. print("=" * 50)
  67. print("测试 nano_banana(Gemini 图模,可切换 model)")
  68. print("=" * 50)
  69. print(f"ROUTER_URL: {ROUTER_URL}")
  70. print(f"tool_id: {TOOL_ID}")
  71. if NANO_BANANA_MODEL:
  72. print(f"model: {NANO_BANANA_MODEL}(经 params 传入)")
  73. else:
  74. print("model: (未传,使用工具默认 / GEMINI_IMAGE_MODEL)")
  75. try:
  76. r = httpx.get(f"{ROUTER_URL}/health", timeout=3)
  77. print(f"Router 状态: {r.json()}")
  78. except httpx.ConnectError:
  79. print(f"无法连接 Router ({ROUTER_URL}),请先: uv run python -m tool_agent")
  80. sys.exit(1)
  81. print("\n--- 校验工具已注册 ---")
  82. tr = httpx.get(f"{ROUTER_URL}/tools", timeout=30)
  83. tr.raise_for_status()
  84. tools = tr.json().get("tools", [])
  85. ids = {t["tool_id"] for t in tools}
  86. if TOOL_ID not in ids:
  87. print(f"错误: {TOOL_ID!r} 不在 GET /tools 中。当前示例: {sorted(ids)[:15]}...")
  88. sys.exit(1)
  89. meta = next(t for t in tools if t["tool_id"] == TOOL_ID)
  90. print(f" {TOOL_ID}: {meta.get('name', '')} (state={meta.get('state')})")
  91. props = (meta.get("input_schema") or {}).get("properties") or {}
  92. if "model" in props:
  93. print(" input_schema 已声明 model(注册与实现应对齐)")
  94. else:
  95. print(" 提示: input_schema 尚无 model 字段,注册表宜补充以便编排知晓可切换模型")
  96. params: dict[str, Any] = {"prompt": TEST_PROMPT}
  97. if NANO_BANANA_MODEL:
  98. params["model"] = NANO_BANANA_MODEL
  99. print("\n--- 调用生图 ---")
  100. print(f"prompt: {TEST_PROMPT[:80]}{'...' if len(TEST_PROMPT) > 80 else ''}")
  101. try:
  102. data = run_tool(params, timeout=180.0)
  103. except (RuntimeError, httpx.HTTPError) as e:
  104. print(f"错误: {e}")
  105. sys.exit(1)
  106. print(f"\n下游返回 keys: {list(data.keys())[:20]}")
  107. if rm := data.get("model"):
  108. print(f"下游报告 model: {rm}")
  109. if NANO_BANANA_MODEL and rm != NANO_BANANA_MODEL:
  110. print(
  111. f"警告: 请求 model={NANO_BANANA_MODEL!r} 与返回 model={rm!r} 不一致(若工具会规范化 ID 可忽略)"
  112. )
  113. if _has_image_payload(data):
  114. print("\n检测到图片相关字段,测试通过!")
  115. return
  116. print("\n未识别到常见图片字段(images / image / candidates[].inlineData 等)。")
  117. print(f"完整结果(截断): {str(data)[:800]}")
  118. sys.exit(1)
  119. if __name__ == "__main__":
  120. main()