"""ToolHub - remote tool library integration.

Wraps the tool-library API at http://43.106.118.91:8001 as tools the Agent can call.
Three tools are registered:
1. toolhub_health - health check
2. toolhub_search - search/discover remote tools (GET /tools)
3. toolhub_call   - call a remote tool (POST /run_tool)

Image parameters always use local file paths:
- Input: image/image_url/... values in ``params`` may be local paths; they are
  uploaded to OSS and replaced by CDN URLs before the request is sent.
- Output: generated images are saved under outputs/ and their local paths returned.

API endpoints (as confirmed via /openapi.json):
    GET  /health    -> health check
    GET  /tools     -> list all tools (including groups and parameter schemas)
    POST /run_tool  -> call a tool with {"tool_id": str, "params": dict}
    POST /chat      -> chat endpoint (not wrapped here)

CLI usage:
    python -m agent.tools.builtin.toolhub health
    python -m agent.tools.builtin.toolhub search --keyword=image
    python -m agent.tools.builtin.toolhub call --tool_id=flux_gen --params='{"prompt":"a cat"}'
"""

import base64
import contextvars
import json
import logging
import mimetypes
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx

from agent.tools import tool, ToolResult

logger = logging.getLogger(__name__)

# ── Configuration ─────────────────────────────────────
TOOLHUB_BASE_URL = "http://43.106.118.91:8001"
DEFAULT_TIMEOUT = 30.0
# Image-generation tools are slow, and cold-starting a cloud machine can take minutes.
CALL_TIMEOUT = 600.0

# OSS upload settings
OSS_BUCKET_NAME = "aigc-admin"
OSS_BUCKET_PATH = "toolhub_images"

# Output directory; a relative path, so it resolves against the process working
# directory (intended to be the project root).
OUTPUT_BASE_DIR = Path("outputs")

# Values with these prefixes are remote references, never local file paths.
_URL_PREFIXES = ("http://", "https://")
_REMOTE_REF_PREFIXES = ("http://", "https://", "data:")

# Non-URL image strings at or below this length are treated as junk, not image data.
_MIN_INLINE_IMAGE_LEN = 100

# Media type -> file extension for saved images; unknown types fall back to ".png".
_EXT_BY_MEDIA_TYPE = {
    "image/png": ".png",
    "image/jpeg": ".jpg",
    "image/jpg": ".jpg",
    "image/webp": ".webp",
    "image/gif": ".gif",
}

# Anything outside this set is replaced with "_" when building file/directory names.
_UNSAFE_PATH_CHARS_RE = re.compile(r"[^A-Za-z0-9._-]")

# Separators used to split search keywords and tool text into tokens.
_TOKEN_SPLIT_RE = re.compile(r"[\s_\-.,/]+")

# trace_id context variable; the runner sets it before executing a tool.
_trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("toolhub_trace_id", default="")


def set_trace_context(trace_id: str):
    """Set the current trace_id (called by the runner); it selects the image output directory."""
    _trace_id_var.set(trace_id)


def _safe_path_component(value: str) -> str:
    """Replace every character that is unsafe in a single file/directory name with "_"."""
    return _UNSAFE_PATH_CHARS_RE.sub("_", value)


def _get_output_dir(tool_id: str) -> Path:
    """Return (creating it if needed) the image output directory for the current trace.

    The directory is ``outputs/<first 12 chars of trace_id before any "@">/``, or
    ``outputs/no_trace_<unix seconds>/`` when no trace_id is set or it yields no
    usable name. ``tool_id`` is accepted for interface compatibility but unused.
    """
    trace_id = str(_trace_id_var.get("") or "")
    # trace_id may contain "@" and other special characters: keep the part before "@"
    # and sanitise it so it can never escape OUTPUT_BASE_DIR (e.g. "../" or "/").
    safe_id = _safe_path_component(trace_id.split("@")[0][:12])
    if safe_id.strip("."):
        out_dir = OUTPUT_BASE_DIR / safe_id
    else:
        out_dir = OUTPUT_BASE_DIR / f"no_trace_{int(time.time())}"
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir


# ── Image helpers ─────────────────────────────────────
async def _upload_to_oss(local_path: str) -> Optional[str]:
    """Upload a local file to OSS and return its CDN URL, or None on any failure.

    Best-effort: every error, including a failing ``cyber_sdk`` import, is logged
    as a warning instead of being raised.
    """
    try:
        # Imported inside the try so a missing SDK degrades to a logged warning.
        from cyber_sdk.ali_oss import upload_localfile

        # Absolute path with forward slashes (presumably to normalise Windows paths).
        safe_path = os.path.abspath(local_path).replace("\\", "/")
        result = await upload_localfile(
            file_path=safe_path,
            bucket_path=OSS_BUCKET_PATH,
            bucket_name=OSS_BUCKET_NAME,
        )
        oss_key = result.get("oss_object_key")
        if oss_key:
            cdn_url = f"https://res.cybertogether.net/{oss_key}"
            logger.info(f"[ToolHub] 图片已上传 OSS: {cdn_url}")
            return cdn_url
        # Previously silent: the upload "succeeded" but produced no usable key.
        logger.warning(f"[ToolHub] OSS 上传未返回 oss_object_key: {result!r}")
    except Exception as e:
        logger.warning(f"[ToolHub] OSS 上传失败: {e}")
    return None


async def _load_image_bytes(img: str) -> Optional[tuple]:
    """Resolve one image reference to ``(bytes, media_type)``, or None if it cannot be.

    Accepts an http(s) URL (downloaded, redirects followed), a ``data:`` URI, or raw
    base64. Non-URL strings of at most ``_MIN_INLINE_IMAGE_LEN`` characters are
    skipped as junk. Download and decode failures are logged, not raised.
    """
    if img.startswith(_URL_PREFIXES):
        # URLs are never length-filtered: ordinary image URLs are often < 100 chars.
        try:
            async with httpx.AsyncClient(timeout=60, follow_redirects=True, trust_env=False) as dl:
                img_resp = await dl.get(img)
                # httpx raises here for 3xx too, hence follow_redirects above.
                img_resp.raise_for_status()
        except Exception as e:
            logger.warning(f"[ToolHub] 图片下载失败: {e}")
            return None
        ct = img_resp.headers.get("content-type", "image/png").split(";")[0].strip()
        if not ct.startswith("image/"):
            # Content-type is not an image type: guess from the URL path instead.
            ct = mimetypes.guess_type(img.split("?")[0])[0] or "image/png"
        return img_resp.content, ct

    if len(img) <= _MIN_INLINE_IMAGE_LEN:
        return None
    try:
        if img.startswith("data:"):
            header, b64 = img.split(",", 1)
            media_type = header.split(";")[0].replace("data:", "") or "image/png"
            return base64.b64decode(b64), media_type
        # Anything else is assumed to be raw base64.
        return base64.b64decode(img), "image/png"
    except ValueError as e:
        # binascii.Error (bad padding) subclasses ValueError; so does a data URI without ",".
        logger.warning(f"[ToolHub] 图片解码失败: {type(e).__name__}: {e}")
        return None


async def _process_images(
    raw_images: List[str],
    tool_id: str,
    failed: Optional[List[str]] = None,
) -> tuple:
    """Turn the images returned by a tool into local files, CDN URLs and LLM payloads.

    Per image: resolve to bytes (downloading URLs) → save under the trace's output
    directory → upload to OSS for a CDN URL → base64-encode for the LLM.

    Args:
        raw_images: image references (http(s) URL, data URI or raw base64);
            non-string and empty entries are skipped.
        tool_id: used (sanitised) as the prefix of saved file names.
        failed: optional list that receives every http(s) URL that could not be
            downloaded or saved, so the caller can still surface it.

    Returns:
        (images_for_llm, cdn_urls, saved_paths)
        - images_for_llm: ``{"type": "base64", "media_type", "data"}`` dicts for the
          runner, so the LLM can view the images multimodally
        - cdn_urls: permanent CDN URLs of images whose OSS upload succeeded
        - saved_paths: local file paths of the saved images
    """
    images_for_llm = []
    cdn_urls = []
    saved_paths = []
    out_dir = _get_output_dir(tool_id)
    file_prefix = _safe_path_component(tool_id)

    def _record_failure(ref: str) -> None:
        # Only URLs are worth reporting back; base64 blobs would flood the output.
        if failed is not None and ref.startswith(_URL_PREFIXES):
            failed.append(ref)

    for idx, img in enumerate(raw_images):
        if not isinstance(img, str) or not img:
            continue
        loaded = await _load_image_bytes(img)
        if not loaded or not loaded[0]:
            _record_failure(img)
            continue
        img_bytes, media_type = loaded

        # 1. Save locally; the ms timestamp separates repeated calls, idx separates images.
        ts = int(time.time() * 1000)
        ext = _EXT_BY_MEDIA_TYPE.get(media_type, ".png")
        save_path = out_dir / f"{file_prefix}_{ts}_{idx}{ext}"
        try:
            save_path.write_bytes(img_bytes)
        except OSError as e:
            logger.warning(f"[ToolHub] 图片保存失败 {save_path}: {e}")
            _record_failure(img)
            continue
        saved_paths.append(str(save_path))

        # 2. Upload to OSS for a CDN URL (best-effort).
        cdn_url = await _upload_to_oss(str(save_path))
        if cdn_url:
            cdn_urls.append(cdn_url)

        # 3. Base64 copy so the LLM can look at the image.
        images_for_llm.append({
            "type": "base64",
            "media_type": media_type,
            "data": base64.b64encode(img_bytes).decode(),
        })

    return images_for_llm, cdn_urls, saved_paths


_SINGLE_IMAGE_PARAMS = ("image", "image_url", "mask_image", "pose_image", "reference_image")
_ARRAY_IMAGE_PARAMS = ("images", "image_urls", "reference_images")


async def _maybe_upload_local(val: str) -> Optional[str]:
    """Upload ``val`` to OSS and return the CDN URL if it is an existing local file; else None."""
    if not isinstance(val, str):
        return None
    if val.startswith(_REMOTE_REF_PREFIXES):
        return None
    try:
        p = Path(val)
        if p.exists() and p.is_file():
            return await _upload_to_oss(str(p.resolve()))
    except Exception as e:
        logger.warning(f"[ToolHub] 本地路径处理失败 {val}: {e}")
    return None


async def _resolve_image_ref(label: str, val: str) -> str:
    """Return the value to send for one image parameter, uploading it if it is a local file.

    Remote references (http(s)/data:) are returned unchanged. An existing local
    file is uploaded and replaced by its CDN URL. Anything else is forwarded
    unchanged with a warning. ``label`` names the parameter in log messages.
    """
    if val.startswith(_REMOTE_REF_PREFIXES):
        return val
    cdn_url = await _maybe_upload_local(val)
    if cdn_url:
        logger.info(f"[ToolHub] {label} 本地路径已替换为 CDN: {cdn_url}")
        return cdn_url
    if os.path.isfile(val):
        # The file exists but the upload failed; the remote server cannot read it either.
        logger.warning(f"[ToolHub] {label}={val!r} 本地文件上传 OSS 失败,原样传递")
    else:
        # Only warned, not raised: the value is still forwarded, but this log line beats
        # the cryptic base64 error the remote service would otherwise produce.
        logger.warning(f"[ToolHub] {label}={val!r} 既不是 URL 也不是存在的本地文件")
    return val


async def _preprocess_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Replace local file paths in image parameters with CDN URLs (uploading to OSS).

    Single-value params: image, image_url, mask_image, pose_image, reference_image.
    Array params: images, image_urls, reference_images.

    Design note: the remote tool service runs with a different cwd than the caller,
    so relative paths would not resolve there. Local paths must be converted to CDN
    URLs on the client; there is no server-side fallback.

    Returns a shallow copy; the caller's dict is never mutated. Empty/None ``params``
    is returned as-is.
    """
    if not params:
        return params

    processed = params.copy()

    for key in _SINGLE_IMAGE_PARAMS:
        val = processed.get(key)
        if isinstance(val, str):
            processed[key] = await _resolve_image_ref(key, val)

    for array_key in _ARRAY_IMAGE_PARAMS:
        items = processed.get(array_key)
        if not isinstance(items, list):
            continue
        new_list = []
        for idx, item in enumerate(items):
            if isinstance(item, str):
                item = await _resolve_image_ref(f"{array_key}[{idx}]", item)
            new_list.append(item)
        processed[array_key] = new_list

    return processed


# ── Search helpers ────────────────────────────────────
def _normalize_text(s: str) -> str:
    """Lower-case ``s`` and drop every non-alphanumeric character (separators, spaces)."""
    return "".join(c for c in s.lower() if c.isalnum())


def _tokenize_text(s: str) -> set:
    """Split lower-cased ``s`` on whitespace and ``_ - . , /`` into a set of non-empty tokens."""
    return {t for t in _TOKEN_SPLIT_RE.split(s.lower()) if t}


def _tool_matches(t: dict, keyword: str) -> bool:
    """Return True if ``keyword`` matches the tool's name/description/tool_id/category.

    Three tiers, any one is enough:
    1) raw substring (strictest, but "nanobanana" misses "nano_banana");
    2) substring after stripping separators on both sides;
    3) token intersection (a multi-word keyword matches if ANY token matches).
    Missing or null fields count as empty text.
    """
    combined = " ".join(
        str(t.get(field) or "") for field in ("name", "description", "tool_id", "category")
    ).lower()
    if keyword.lower() in combined:
        return True
    kw_norm = _normalize_text(keyword)
    if kw_norm and kw_norm in _normalize_text(combined):
        return True
    kw_tokens = _tokenize_text(keyword)
    return bool(kw_tokens and kw_tokens & _tokenize_text(combined))


def _format_tool_summary(t: dict) -> str:
    """Render one /tools entry (id, state, runtime, category, groups, params) for the LLM."""
    schema = t.get("input_schema") or {}
    input_props = schema.get("properties") or {}
    required_fields = schema.get("required") or []

    params_desc = []
    for name, info in input_props.items():
        if not isinstance(info, dict):
            info = {}
        req = "必填" if name in required_fields else "可选"
        desc = info.get("description", "")
        default_str = f", 默认={info['default']}" if info.get("default") is not None else ""
        enum_str = f", 可选值={info['enum']}" if info.get("enum") else ""
        params_desc.append(
            f" - {name} ({info.get('type','any')}, {req}): {desc}{default_str}{enum_str}"
        )

    group_str = ""
    if t.get("group_ids"):
        group_str = f"\n 所属分组: {', '.join(map(str, t['group_ids']))}"
    # .get() instead of [] so a tool entry with a missing field cannot abort the whole search.
    tool_block = (
        f"[{t.get('tool_id', '')}] {t.get('name', '')}\n"
        f" 状态: {t.get('state', '')} | 运行时: {t.get('backend_runtime', '')} | 分类: {t.get('category','')}"
        f"{group_str}\n"
        f" 描述: {t.get('description', '')}"
    )
    if params_desc:
        tool_block += "\n 参数:\n" + "\n".join(params_desc)
    else:
        tool_block += "\n 参数: 无"
    return tool_block


def _format_group_summary(g: dict) -> str:
    """Render one tool group (ordered call sequence plus usage notes) for the LLM."""
    return (
        f"[组: {g.get('group_id', '')}] {g.get('name', '')}\n"
        f" 调用顺序: {' → '.join(map(str, g.get('usage_order') or []))}\n"
        f" 说明: {g.get('usage_example', '')}"
    )


def _collect_result_images(result: dict) -> list:
    """Gather image references from a tool result's "image" (str) and "images" (list) fields."""
    raw_images = []
    if result.get("image") and isinstance(result["image"], str):
        raw_images.append(result["image"])
    if result.get("images") and isinstance(result["images"], list):
        raw_images.extend(result["images"])
    return raw_images


# ── Tool implementations ──────────────────────────────
@tool(
    display={
        "zh": {"name": "ToolHub 健康检查", "params": {}},
        "en": {"name": "ToolHub Health Check", "params": {}},
    },
    groups=["toolhub"],
)
async def toolhub_health() -> ToolResult:
    # NOTE(review): the Chinese docstring below is kept verbatim because @tool presumably
    # exposes it to the agent as the tool description; confirm before translating it.
    """检查 ToolHub 远程工具库服务是否可用

    检查 ToolHub 服务的健康状态,确认服务是否正常运行。
    建议在调用其他 toolhub 工具之前先检查。

    Returns:
        ToolResult 包含服务健康状态信息
    """
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT, trust_env=False) as client:
            resp = await client.get(f"{TOOLHUB_BASE_URL}/health")
            resp.raise_for_status()
            data = resp.json()
        return ToolResult(
            title="ToolHub 健康检查",
            output=json.dumps(data, ensure_ascii=False, indent=2),
            long_term_memory=f"ToolHub service at {TOOLHUB_BASE_URL} is healthy.",
        )
    except httpx.ConnectError:
        return ToolResult(
            title="ToolHub 健康检查",
            output="",
            error=f"无法连接到 ToolHub 服务 {TOOLHUB_BASE_URL},请确认服务已启动。",
        )
    except Exception as e:
        # Some httpx exceptions have an empty str(); always include the type name.
        err_msg = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
        return ToolResult(
            title="ToolHub 健康检查",
            output="",
            error=err_msg,
        )


@tool(
    display={
        "zh": {"name": "搜索 ToolHub 工具", "params": {"keyword": "搜索关键词"}},
        "en": {"name": "Search ToolHub", "params": {"keyword": "Search keyword"}},
    },
    groups=["toolhub"],
)
async def toolhub_search(keyword: Optional[str] = None) -> ToolResult:
    # NOTE(review): docstring kept verbatim; it is presumably the agent-facing tool description.
    """搜索 ToolHub 远程工具库中可用的工具

    从 ToolHub 工具库中获取可用工具列表,返回每个工具的完整信息,包括:
    tool_id、名称、分类、状态、参数列表(含类型、是否必填、默认值)、输出 schema、
    分组信息(如 RunComfy 生命周期组)等。

    调用 toolhub_call 之前,应先使用此工具了解目标工具的 tool_id 和所需参数。
    不填 keyword 则返回所有工具。

    Args:
        keyword: 搜索关键词,用于过滤工具名称或描述(客户端过滤);为空则返回所有工具

    Returns:
        ToolResult 包含匹配的工具列表及其参数说明
    """
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT, trust_env=False) as client:
            resp = await client.get(f"{TOOLHUB_BASE_URL}/tools")
            resp.raise_for_status()
            data = resp.json()

        # Ignore malformed (non-dict) entries instead of crashing the whole search.
        tools = [t for t in (data.get("tools") or []) if isinstance(t, dict)]
        groups = [g for g in (data.get("groups") or []) if isinstance(g, dict)]

        # Client-side keyword filtering (see _tool_matches for the three tiers).
        if keyword:
            tools = [t for t in tools if _tool_matches(t, keyword)]
        total = len(tools)

        summaries = [_format_tool_summary(t) for t in tools]

        # Only show groups that the matched tools actually belong to, to avoid noise.
        relevant_group_ids = {gid for t in tools for gid in (t.get("group_ids") or [])}
        group_summary = [
            _format_group_summary(g) for g in groups if g.get("group_id") in relevant_group_ids
        ]

        output_parts = [f"共找到 {total} 个工具({'关键词: ' + keyword if keyword else '全量'}):\n"]
        output_parts.append("\n\n".join(summaries))
        if group_summary:
            output_parts.append("\n\n=== 工具分组(有顺序依赖)===\n" + "\n\n".join(group_summary))

        return ToolResult(
            title=f"ToolHub 搜索{f': {keyword}' if keyword else '(全量)'}",
            output="\n".join(output_parts),
            long_term_memory=(
                f"ToolHub 共 {total} 个工具: "
                + ", ".join(str(t.get("tool_id", "")) for t in tools[:15])
                + ("..." if total > 15 else "")
            ),
        )
    except httpx.TimeoutException as e:
        # TimeoutException covers ConnectTimeout, ReadTimeout, WriteTimeout and PoolTimeout.
        return ToolResult(
            title="ToolHub /tools 超时",
            output="",
            error=f"ToolHub 的 /tools 接口在 {DEFAULT_TIMEOUT:.0f}s 内未响应({type(e).__name__})。"
            f"服务器可能在检测各工具状态导致列表慢,请稍后重试或联系维护者。",
        )
    except httpx.ConnectError as e:
        return ToolResult(
            title="ToolHub 连接失败",
            output="",
            error=f"无法连接到 ToolHub 服务 {TOOLHUB_BASE_URL}:{type(e).__name__}: {e}",
        )
    except Exception as e:
        # Some httpx exceptions have an empty str(); always include the type name.
        err_msg = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
        return ToolResult(
            title="搜索 ToolHub 工具失败",
            output="",
            error=err_msg,
        )


@tool(
    display={
        "zh": {
            "name": "调用 ToolHub 工具",
            "params": {"tool_id": "工具ID", "params": "工具参数"},
        },
        "en": {
            "name": "Call ToolHub Tool",
            "params": {"tool_id": "Tool ID", "params": "Tool parameters"},
        },
    },
    groups=["toolhub"],
)
async def toolhub_call(
    tool_id: str,
    params: Optional[Dict[str, Any]] = None,
) -> ToolResult:
    # NOTE(review): docstring kept verbatim; it is presumably the agent-facing tool description.
    """调用 ToolHub 远程工具库中的指定工具

    通过 tool_id 调用 ToolHub 工具库中的某个工具,传入该工具所需的参数。
    不同工具的参数不同,请先用 toolhub_search 查询目标工具的参数说明。

    图片参数(image、image_url、mask_image、pose_image、images)直接传本地文件路径即可,
    系统会自动上传。生成的图片会自动保存到本地 outputs/ 目录,返回结果中的 saved_files 字段包含本地文件路径。

    注意:部分工具为异步生命周期(如 RunComfy、即梦、FLUX),需要按分组顺序
    依次调用多个工具(如先 launch → 再 executor → 再 stop)。

    Args:
        tool_id: 要调用的工具 ID(从 toolhub_search 获取)
        params: 工具参数字典,键值对根据目标工具的参数定义决定。
                图片参数可直接使用本地文件路径(如 "/path/to/image.png")。

    Returns:
        ToolResult 包含工具执行结果,图片结果通过 saved_files 返回本地路径
    """
    try:
        # Local file paths in image params are uploaded and replaced by CDN URLs.
        params = await _preprocess_params(params or {})
        payload = {
            "tool_id": tool_id,
            "params": params,
        }
        async with httpx.AsyncClient(timeout=CALL_TIMEOUT, trust_env=False) as client:
            resp = await client.post(f"{TOOLHUB_BASE_URL}/run_tool", json=payload)
            resp.raise_for_status()
            data = resp.json()

        if data.get("status") != "success":
            # A null/missing error still yields a message; non-string errors are serialised.
            error_msg = data.get("error") or "未知错误"
            if not isinstance(error_msg, str):
                error_msg = json.dumps(error_msg, ensure_ascii=False)
            return ToolResult(
                title=f"ToolHub [{tool_id}] 执行失败",
                output=json.dumps(data, ensure_ascii=False, indent=2),
                error=error_msg,
            )

        result = data.get("result", {})
        result_str = json.dumps(result, ensure_ascii=False, indent=2)

        # Image pipeline: download → save locally → upload to OSS → CDN URL.
        images = []
        if isinstance(result, dict):
            raw_images = _collect_result_images(result)
            if raw_images:
                failed_urls: List[str] = []
                images, cdn_urls, saved_paths = await _process_images(
                    raw_images, tool_id, failed=failed_urls
                )
                # Text output drops the raw image data and lists local paths instead.
                result_display = {k: v for k, v in result.items() if k not in ("image", "images")}
                result_display["image_count"] = len(images)
                if saved_paths:
                    result_display["saved_files"] = saved_paths
                if cdn_urls:
                    result_display["cdn_urls"] = cdn_urls
                if failed_urls:
                    # Keep URLs we could not fetch/save, since "image"/"images" were removed above.
                    result_display["unsaved_image_urls"] = failed_urls
                result_str = json.dumps(result_display, ensure_ascii=False, indent=2)

        return ToolResult(
            title=f"ToolHub [{tool_id}] 执行成功",
            output=result_str,
            long_term_memory=f"Called ToolHub tool '{tool_id}' → success",
            images=images,
        )
    except httpx.TimeoutException as e:
        return ToolResult(
            title=f"ToolHub [{tool_id}] 调用超时",
            output="",
            error=f"调用工具 {tool_id} 超时({CALL_TIMEOUT:.0f}s,{type(e).__name__}),"
            f"图像生成类工具可能需要更长时间。",
        )
    except httpx.HTTPStatusError as e:
        # Include the response body: whatever detail the server returned is far more
        # useful to the agent than the bare status line. (Non-streaming, so it is read.)
        body = e.response.text.strip()[:2000]
        err_msg = f"{type(e).__name__}: HTTP {e.response.status_code} {e.response.reason_phrase}"
        if body:
            err_msg += f": {body}"
        return ToolResult(
            title=f"ToolHub [{tool_id}] 调用失败",
            output="",
            error=err_msg,
        )
    except Exception as e:
        # Some httpx exceptions have an empty str(); always include the type name.
        err_msg = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
        return ToolResult(
            title=f"ToolHub [{tool_id}] 调用失败",
            output="",
            error=err_msg,
        )


# NOTE: image_uploader and image_downloader are no longer registered as Agent tools.
# toolhub_call already contains the full image pipeline (inputs auto-uploaded, outputs
# auto-downloaded and saved), so separate upload/download tools are unnecessary.
# They are kept for internal use; the CLI below does not currently expose them.


async def image_uploader(local_path: str) -> ToolResult:
    """Upload a local image to OSS and return its CDN URL (internal helper, not an Agent tool).

    Returns a ToolResult whose output is JSON with local_path, cdn_url and file_size,
    or an error when the path is missing, not a file, or the upload fails.
    """
    p = Path(local_path)
    if not p.exists():
        return ToolResult(
            title="图片上传失败",
            output="",
            error=f"文件不存在: {local_path}",
        )
    if not p.is_file():
        return ToolResult(
            title="图片上传失败",
            output="",
            error=f"路径不是文件: {local_path}",
        )

    cdn_url = await _upload_to_oss(str(p.resolve()))
    if not cdn_url:
        return ToolResult(
            title="图片上传失败",
            output="",
            error=f"OSS 上传失败,请检查文件路径和网络连接: {local_path}",
        )

    result = {
        "local_path": str(p.resolve()),
        "cdn_url": cdn_url,
        "file_size": os.path.getsize(p),
    }
    return ToolResult(
        title="图片上传成功",
        output=json.dumps(result, ensure_ascii=False, indent=2),
        long_term_memory=f"Uploaded {local_path} → {cdn_url}",
    )


async def image_downloader(url: str, save_path: str = "") -> ToolResult:
    """Download a remote image to a local file (internal helper, not an Agent tool).

    Args:
        url: http(s) URL of the image; redirects are followed.
        save_path: destination path. When empty, the file goes into the trace's
            output directory, named after the URL's last path segment if it has an
            image extension, else ``download_<unix seconds>.png``.

    Returns:
        ToolResult whose output is JSON with save_path, file_size and source_url,
        or an error for an invalid URL, an HTTP error status, or any other failure.
    """
    from urllib.parse import urlparse, unquote

    if not url.startswith(_URL_PREFIXES):
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"无效的 URL(必须以 http:// 或 https:// 开头): {url}",
        )

    # Derive a save path automatically when none was given.
    if not save_path:
        out_dir = _get_output_dir("download")
        url_path = urlparse(url).path
        filename = Path(unquote(url_path)).name if url_path else ""
        if not filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp")):
            filename = f"download_{int(time.time())}.png"
        save_path = str(out_dir / filename)

    p = Path(save_path)
    p.parent.mkdir(parents=True, exist_ok=True)

    try:
        async with httpx.AsyncClient(timeout=60.0, follow_redirects=True, trust_env=False) as client:
            resp = await client.get(url)
            resp.raise_for_status()
        p.write_bytes(resp.content)
        result = {
            "save_path": str(p.resolve()),
            "file_size": os.path.getsize(p),
            "source_url": url,
        }
        return ToolResult(
            title="图片下载成功",
            output=json.dumps(result, ensure_ascii=False, indent=2),
            long_term_memory=f"Downloaded {url} → {save_path}",
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"HTTP 错误 {e.response.status_code}: {url}",
        )
    except Exception as e:
        # Some httpx exceptions have an empty str(); always include the type name.
        detail = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"下载失败: {detail}",
        )


if __name__ == "__main__":
    import asyncio
    import sys
    import uuid

    COMMANDS = {
        "health": toolhub_health,
        "search": toolhub_search,
        "call": toolhub_call,
    }

    def _parse_args(argv):
        """Parse ``--key=value`` arguments into kwargs (``-`` in keys becomes ``_``).

        Only values that look like JSON objects/arrays (e.g. ``--params='{...}'``) are
        JSON-decoded. Everything else stays a string, so ``--keyword=123`` or
        ``--trace_id=null`` are not turned into int/None. Other arguments are ignored.
        """
        kwargs = {}
        for arg in argv:
            if not (arg.startswith("--") and "=" in arg):
                continue
            k, v = arg.split("=", 1)
            k = k.lstrip("-").replace("-", "_")
            if v.lstrip().startswith(("{", "[")):
                try:
                    v = json.loads(v)
                except (json.JSONDecodeError, ValueError):
                    pass
            kwargs[k] = v
        return kwargs

    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print(f"用法: python {sys.argv[0]} [--key=value ...]")
        print(f"可用命令: {', '.join(COMMANDS.keys())}")
        sys.exit(0)

    cmd = sys.argv[1]
    if cmd not in COMMANDS:
        print(f"未知命令: {cmd},可用: {', '.join(COMMANDS.keys())}")
        sys.exit(1)

    kwargs = _parse_args(sys.argv[2:])

    # trace_id (selects the image output directory): CLI arg > env var > auto-generated.
    trace_id = str(
        kwargs.pop("trace_id", None) or os.getenv("TRACE_ID") or f"cli-{uuid.uuid4().hex[:8]}"
    )
    set_trace_context(trace_id)
    # asyncio.run copies the current context, so the trace_id set above is visible in the tool.
    result = asyncio.run(COMMANDS[cmd](**kwargs))

    # Avoid double JSON encoding: toolhub_call already json.dumps its result dict, so
    # if output is a JSON string, parse it back before embedding it in the CLI's JSON.
    # Otherwise callers would get "output" as a stringified JSON blob.
    output_value = result.output
    if isinstance(output_value, str):
        stripped = output_value.lstrip()
        if stripped.startswith(("{", "[")):
            try:
                output_value = json.loads(output_value)
            except (json.JSONDecodeError, ValueError):
                pass  # not JSON text; keep as-is

    out = {"trace_id": trace_id, "output": output_value}
    if result.error:
        out["error"] = result.error
    if result.metadata:
        out["metadata"] = result.metadata
    print(json.dumps(out, ensure_ascii=False, indent=2))