| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617 |
- """
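ToolHub - remote tool library integration module.

Wraps the tool library API served at http://43.106.118.91:8001 as tools that
the Agent can call.

Tools defined in this module:
1. toolhub_health   - health check (GET /health)
2. toolhub_search   - search / discover remote tools (GET /tools)
3. toolhub_call     - invoke a remote tool (POST /run_tool)
4. image_uploader   - upload a local image to OSS and return its CDN URL
5. image_downloader - download an image from an HTTP(S) URL to a local file

Actual API endpoints (confirmed via /openapi.json):
GET  /health   -> health check
GET  /tools    -> list all tools (including groups and parameter schemas)
POST /run_tool -> invoke a tool: {"tool_id": str, "params": dict}
POST /chat     -> chat endpoint (not wrapped in this module)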
- """
- import base64
- import contextvars
- import json
- import logging
- import mimetypes
- import time
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- import httpx
- from agent.tools import tool, ToolResult
logger = logging.getLogger(__name__)

# -- Configuration ---------------------------------------------------------
# Base URL of the remote ToolHub service.
TOOLHUB_BASE_URL = "http://43.106.118.91:8001"
# Default HTTP timeout, in seconds.
DEFAULT_TIMEOUT = 30.0
# Timeout for tool invocations, in seconds: image-generation tools are slow
# and a cloud machine may need several minutes to start.
CALL_TIMEOUT = 600.0

# OSS upload settings.
OSS_BUCKET_NAME = "aigc-admin"
OSS_BUCKET_PATH = "toolhub_images"

# Output directory (a relative path, intended to sit under the project root).
OUTPUT_BASE_DIR = Path("outputs")

# Context variable carrying the trace_id; the runner sets it before a tool runs.
_trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "toolhub_trace_id",
    default="",
)
def set_trace_context(trace_id: str) -> None:
    """Store the current trace_id for later use when images are saved.

    Meant to be called by the runner.

    Args:
        trace_id: Identifier of the trace being executed.
    """
    _trace_id_var.set(trace_id)
def _get_output_dir(tool_id: str) -> Path:
    """Return the image output directory, creating it if necessary.

    With a trace_id set the directory is ``outputs/<id>/``; otherwise it is
    ``outputs/no_trace_<unix seconds>/``.

    Args:
        tool_id: Not consulted when choosing the directory.

    Returns:
        Path of the (now existing) output directory.
    """
    current_trace = _trace_id_var.get("")
    if not current_trace:
        dir_name = f"no_trace_{int(time.time())}"
    else:
        # trace_id may contain "@" and similar characters: keep only the part
        # before the first "@", capped at 12 characters.
        dir_name = current_trace.partition("@")[0][:12]

    target = OUTPUT_BASE_DIR / dir_name
    target.mkdir(parents=True, exist_ok=True)
    return target
- # ── 图片处理辅助 ─────────────────────────────────────
async def _upload_to_oss(local_path: str) -> Optional[str]:
    """Upload a local file to OSS and return its CDN URL.

    Best effort: every failure is logged as a warning and reported as
    ``None``; no exception escapes this function.

    Args:
        local_path: Path of the file to upload (relative or absolute).

    Returns:
        ``https://res.cybertogether.net/<oss_object_key>`` on success, or
        ``None`` when the upload raised or the SDK response carried no
        ``oss_object_key``.
    """
    try:
        # Imported inside the try so that a missing SDK is treated like any
        # other upload failure (warning + None).
        import os

        from cyber_sdk.ali_oss import upload_localfile

        # Absolute path with backslashes normalised to forward slashes.
        safe_path = os.path.abspath(local_path).replace("\\", "/")
        result = await upload_localfile(
            file_path=safe_path,
            bucket_path=OSS_BUCKET_PATH,
            bucket_name=OSS_BUCKET_NAME,
        )
        oss_key = result.get("oss_object_key")
    except Exception as e:
        logger.warning(f"[ToolHub] OSS 上传失败: {e}")
        return None

    if not oss_key:
        # Without this warning the failure would be completely silent.
        logger.warning(f"[ToolHub] OSS 上传失败: 返回结果缺少 oss_object_key ({local_path})")
        return None

    cdn_url = f"https://res.cybertogether.net/{oss_key}"
    logger.info(f"[ToolHub] 图片已上传 OSS: {cdn_url}")
    return cdn_url
async def _load_image_bytes(img: str) -> Optional[tuple]:
    """Resolve one image reference into ``(raw_bytes, media_type)``.

    Accepts an http(s) URL (which is downloaded), a ``data:`` URI, or a bare
    base64 string. Returns ``None`` after logging a warning when the download
    or the decoding fails.
    """
    media_type = "image/png"

    if img.startswith(("http://", "https://")):
        try:
            async with httpx.AsyncClient(timeout=60, trust_env=False) as dl:
                img_resp = await dl.get(img)
                img_resp.raise_for_status()
            ct = img_resp.headers.get("content-type", "image/png").split(";")[0].strip()
            if not ct.startswith("image/"):
                # Header is not an image type: guess from the URL path instead.
                ct = mimetypes.guess_type(img.split("?")[0])[0] or "image/png"
            return img_resp.content, ct
        except Exception as e:
            logger.warning(f"[ToolHub] 图片下载失败: {e}")
            return None

    try:
        if img.startswith("data:"):
            header, b64 = img.split(",", 1)
            media_type = header.split(";")[0].replace("data:", "") or media_type
            return base64.b64decode(b64), media_type
        # Anything else is treated as raw base64.
        return base64.b64decode(img), media_type
    except ValueError as e:
        # binascii.Error (bad padding) and a data URI without "," are both
        # ValueError; one bad image must not abort the whole batch.
        logger.warning(f"[ToolHub] 图片解码失败: {e}")
        return None


async def _process_images(raw_images: List[str], tool_id: str) -> tuple:
    """Process the list of images returned by a tool.

    For each image: download (if it is a URL) -> save locally -> upload to
    OSS -> collect the CDN URL. Images that cannot be downloaded or decoded
    are skipped with a warning.

    Non-URL strings of 100 characters or fewer are ignored as too short to be
    image data; http(s) URLs are accepted at any length.

    Args:
        raw_images: Image references (http(s) URL, ``data:`` URI, or raw
            base64). Non-string entries are ignored.
        tool_id: Used as the prefix of the saved file names.

    Returns:
        ``(images_for_llm, cdn_urls, saved_paths)`` where
        - images_for_llm: list of ``{"type": "base64", "media_type", "data"}``
          dicts for multimodal viewing by the LLM,
        - cdn_urls: CDN URLs of the images whose upload succeeded,
        - saved_paths: local file paths of the saved images.
    """
    images_for_llm = []
    cdn_urls = []
    saved_paths = []
    extensions = {
        "image/png": ".png",
        "image/jpeg": ".jpg",
        "image/webp": ".webp",
        "image/gif": ".gif",
    }

    out_dir = _get_output_dir(tool_id)

    for idx, img in enumerate(raw_images):
        if not isinstance(img, str):
            continue
        is_url = img.startswith(("http://", "https://"))
        # The length heuristic only makes sense for inline data; a perfectly
        # valid URL is often shorter than 100 characters.
        if not is_url and len(img) <= 100:
            continue

        loaded = await _load_image_bytes(img)
        if loaded is None:
            continue
        img_bytes, media_type = loaded
        if not img_bytes:
            continue

        # 1. Save locally (millisecond timestamp separates repeated calls).
        ts = int(time.time() * 1000)
        ext = extensions.get(media_type, ".png")
        save_path = out_dir / f"{tool_id}_{ts}_{idx}{ext}"
        save_path.write_bytes(img_bytes)
        saved_paths.append(str(save_path))

        # 2. Upload to OSS to obtain a CDN URL (best effort).
        cdn_url = await _upload_to_oss(str(save_path))
        if cdn_url:
            cdn_urls.append(cdn_url)

        # 3. Base64 copy for multimodal viewing by the LLM.
        b64_data = base64.b64encode(img_bytes).decode()
        images_for_llm.append({"type": "base64", "media_type": media_type, "data": b64_data})

    return images_for_llm, cdn_urls, saved_paths
- async def _preprocess_params(params: Dict[str, Any]) -> Dict[str, Any]:
- """
- 预处理工具参数:检测本地文件路径,自动上传到 OSS 并替换为 CDN URL。
- 支持的参数名:image, image_url, mask_image, pose_image, images (数组)
- """
- if not params:
- return params
- processed = params.copy()
- # 单个图片参数
- for key in ("image", "image_url", "mask_image", "pose_image"):
- if key in processed and isinstance(processed[key], str):
- val = processed[key]
- # 检测是否为本地路径(不是 http/https/data: 开头)
- if not val.startswith(("http://", "https://", "data:")):
- # 尝试读取本地文件
- try:
- from pathlib import Path
- p = Path(val)
- if p.exists() and p.is_file():
- logger.info(f"[ToolHub] 检测到本地文件 {key}={val},上传到 OSS...")
- cdn_url = await _upload_to_oss(str(p.resolve()))
- if cdn_url:
- processed[key] = cdn_url
- logger.info(f"[ToolHub] {key} 已替换为 CDN URL: {cdn_url}")
- else:
- logger.warning(f"[ToolHub] {key} 上传失败,保持原路径")
- except Exception as e:
- logger.warning(f"[ToolHub] {key} 路径处理失败: {e}")
- # images 数组参数
- if "images" in processed and isinstance(processed["images"], list):
- new_images = []
- for idx, img in enumerate(processed["images"]):
- if isinstance(img, str) and not img.startswith(("http://", "https://", "data:")):
- try:
- from pathlib import Path
- p = Path(img)
- if p.exists() and p.is_file():
- logger.info(f"[ToolHub] 检测到本地文件 images[{idx}]={img},上传到 OSS...")
- cdn_url = await _upload_to_oss(str(p.resolve()))
- if cdn_url:
- new_images.append(cdn_url)
- logger.info(f"[ToolHub] images[{idx}] 已替换为 CDN URL: {cdn_url}")
- else:
- new_images.append(img)
- else:
- new_images.append(img)
- except Exception as e:
- logger.warning(f"[ToolHub] images[{idx}] 路径处理失败: {e}")
- new_images.append(img)
- else:
- new_images.append(img)
- processed["images"] = new_images
- return processed
- # ── 工具实现 ──────────────────────────────────────────
@tool(
    display={
        "zh": {"name": "ToolHub 健康检查", "params": {}},
        "en": {"name": "ToolHub Health Check", "params": {}},
    }
)
async def toolhub_health() -> ToolResult:
    """Check whether the remote ToolHub tool library service is available

    Queries the health endpoint of the ToolHub service to confirm that it is
    running normally. It is recommended to run this check before calling the
    other toolhub tools.

    Returns:
        ToolResult containing the service health information
    """
    title = "ToolHub 健康检查"
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT, trust_env=False) as client:
            resp = await client.get(f"{TOOLHUB_BASE_URL}/health")
            resp.raise_for_status()
            data = resp.json()
        return ToolResult(
            title=title,
            output=json.dumps(data, ensure_ascii=False, indent=2),
            long_term_memory=f"ToolHub service at {TOOLHUB_BASE_URL} is healthy.",
        )
    except httpx.ConnectError:
        return ToolResult(
            title=title,
            output="",
            error=f"无法连接到 ToolHub 服务 {TOOLHUB_BASE_URL},请确认服务已启动。",
        )
    except Exception as e:
        return ToolResult(
            title=title,
            output="",
            # Some exceptions carry no message; fall back to the class name so
            # the error field is never an empty string.
            error=str(e) or type(e).__name__,
        )
def _toolhub_text(value: Any) -> str:
    """Render *value* as text, mapping ``None`` to the empty string."""
    return "" if value is None else str(value)


def _format_tool_summary(t: Dict[str, Any]) -> str:
    """Build the multi-line text summary of one tool entry from GET /tools."""
    schema = t.get("input_schema")
    if not isinstance(schema, dict):
        schema = {}
    input_props = schema.get("properties")
    if not isinstance(input_props, dict):
        input_props = {}
    required_fields = schema.get("required") or []

    params_desc = []
    for name, info in input_props.items():
        if not isinstance(info, dict):
            info = {}
        req = "必填" if name in required_fields else "可选"
        desc = _toolhub_text(info.get("description"))
        default_str = f", 默认={info['default']}" if info.get("default") is not None else ""
        enum_str = f", 可选值={info['enum']}" if info.get("enum") else ""
        params_desc.append(
            f" - {name} ({info.get('type','any')}, {req}): {desc}{default_str}{enum_str}"
        )

    group_str = ""
    group_ids = t.get("group_ids")
    if group_ids and isinstance(group_ids, (list, tuple)):
        group_str = f"\n 所属分组: {', '.join(map(str, group_ids))}"

    tool_block = (
        f"[{_toolhub_text(t.get('tool_id'))}] {_toolhub_text(t.get('name'))}\n"
        f" 状态: {_toolhub_text(t.get('state'))}"
        f" | 运行时: {_toolhub_text(t.get('backend_runtime'))}"
        f" | 分类: {_toolhub_text(t.get('category'))}"
        f"{group_str}\n"
        f" 描述: {_toolhub_text(t.get('description'))}"
    )
    if params_desc:
        tool_block += "\n 参数:\n" + "\n".join(params_desc)
    else:
        tool_block += "\n 参数: 无"
    return tool_block


def _format_group_summary(g: Dict[str, Any]) -> str:
    """Build the text summary (call order and usage note) of one tool group."""
    usage_order = g.get("usage_order") or []
    return (
        f"[组: {_toolhub_text(g.get('group_id'))}] {_toolhub_text(g.get('name'))}\n"
        f" 调用顺序: {' → '.join(map(str, usage_order))}\n"
        f" 说明: {_toolhub_text(g.get('usage_example'))}"
    )


@tool(
    display={
        "zh": {"name": "搜索 ToolHub 工具", "params": {"keyword": "搜索关键词"}},
        "en": {"name": "Search ToolHub", "params": {"keyword": "Search keyword"}},
    }
)
async def toolhub_search(keyword: Optional[str] = None) -> ToolResult:
    """Search the tools available in the remote ToolHub tool library

    Fetches the tool list from ToolHub and returns, for every tool: tool_id,
    name, state, runtime, category, group membership, description and the
    parameter list (type, required/optional, description, default, allowed
    values), followed by the tool groups and their call order.

    Use this tool before toolhub_call to learn the tool_id and required
    parameters of the target tool. Without a keyword all tools are returned.

    Args:
        keyword: Search keyword; matched case-insensitively (client-side)
            against each tool's name, description, tool_id and category.
            Leave empty to return all tools.

    Returns:
        ToolResult containing the matching tools and their parameter notes
    """
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT, trust_env=False) as client:
            resp = await client.get(f"{TOOLHUB_BASE_URL}/tools")
            resp.raise_for_status()
            data = resp.json()

        # Tolerate null lists and skip malformed (non-object) entries so one
        # bad record cannot fail the whole search.
        tools = [t for t in (data.get("tools") or []) if isinstance(t, dict)]
        groups = [g for g in (data.get("groups") or []) if isinstance(g, dict)]

        # Client-side keyword filter.
        if keyword:
            kw = keyword.lower()
            tools = [
                t for t in tools
                if any(
                    kw in _toolhub_text(t.get(field)).lower()
                    for field in ("name", "description", "tool_id", "category")
                )
            ]

        total = len(tools)

        # Structured summaries for the LLM.
        summaries = [_format_tool_summary(t) for t in tools]
        # Group usage notes.
        group_summary = [_format_group_summary(g) for g in groups]

        scope = "关键词: " + keyword if keyword else "全量"
        output_parts = [f"共找到 {total} 个工具({scope}):\n"]
        output_parts.append("\n\n".join(summaries))
        if group_summary:
            output_parts.append("\n\n=== 工具分组(有顺序依赖)===\n" + "\n\n".join(group_summary))

        title_suffix = f": {keyword}" if keyword else "(全量)"
        return ToolResult(
            title=f"ToolHub 搜索{title_suffix}",
            output="\n".join(output_parts),
            long_term_memory=(
                f"ToolHub 共 {total} 个工具: "
                + ", ".join(_toolhub_text(t.get("tool_id")) for t in tools[:15])
                + ("..." if total > 15 else "")
            ),
        )
    except Exception as e:
        return ToolResult(
            title="搜索 ToolHub 工具失败",
            output="",
            # Some exceptions carry no message; never report an empty error.
            error=str(e) or type(e).__name__,
        )
@tool(
    display={
        "zh": {
            "name": "调用 ToolHub 工具",
            "params": {"tool_id": "工具ID", "params": "工具参数"},
        },
        "en": {
            "name": "Call ToolHub Tool",
            "params": {"tool_id": "Tool ID", "params": "Tool parameters"},
        },
    }
)
async def toolhub_call(
    tool_id: str,
    params: Optional[Dict[str, Any]] = None,
) -> ToolResult:
    """Call a specific tool in the remote ToolHub tool library

    Invokes the ToolHub tool identified by tool_id with the parameters that
    tool requires. Parameters differ per tool; use toolhub_search first to
    look up the parameter description of the target tool.

    Note: some tools follow an asynchronous lifecycle (e.g. RunComfy, Jimeng,
    FLUX) and must be called in group order (e.g. launch -> executor -> stop).

    Parameters are passed through the params dict; key names and types must
    match the tool definition. For example, to call the image stitching tool:
        tool_id="image_stitcher"
        params={"images": [...], "direction": "grid", "columns": 2}

    Args:
        tool_id: ID of the tool to call (obtained from toolhub_search)
        params: Tool parameter dict; keys and values depend on the target
            tool's parameter definition

    Returns:
        ToolResult containing the tool execution result
    """
    try:
        # Pre-process parameters: local file paths are uploaded and replaced
        # by CDN URLs.
        params = await _preprocess_params(params or {})

        payload = {
            "tool_id": tool_id,
            "params": params,
        }

        async with httpx.AsyncClient(timeout=CALL_TIMEOUT, trust_env=False) as client:
            resp = await client.post(
                f"{TOOLHUB_BASE_URL}/run_tool", json=payload
            )
            resp.raise_for_status()
            data = resp.json()

        if data.get("status") != "success":
            # "or" also covers an explicit null/empty error field.
            error_msg = data.get("error") or "未知错误"
            return ToolResult(
                title=f"ToolHub [{tool_id}] 执行失败",
                output=json.dumps(data, ensure_ascii=False, indent=2),
                error=error_msg,
            )

        result = data.get("result", {})
        result_str = json.dumps(result, ensure_ascii=False, indent=2)

        # Extract images and process them uniformly
        # (download -> save locally -> upload to OSS -> CDN URL).
        images = []
        if isinstance(result, dict):
            # Collect every image: the single "image" field plus the
            # "images" list field.
            raw_images = []
            single_image = result.get("image")
            if single_image and isinstance(single_image, str):
                raw_images.append(single_image)
            image_list = result.get("images")
            if image_list and isinstance(image_list, list):
                raw_images.extend(image_list)

            if raw_images:
                images, cdn_urls, saved_paths = await _process_images(raw_images, tool_id)

                # Text output without the raw image payloads.
                result_display = {k: v for k, v in result.items() if k not in ("image", "images")}
                if cdn_urls:
                    result_display["cdn_urls"] = cdn_urls
                    result_display["_note"] = (
                        "图片已上传至 CDN(永久链接),可通过 cdn_urls 访问、传给其他工具或下载保存。"
                        "同时也作为附件附加在本条消息中可直接查看。"
                    )
                if len(cdn_urls) < len(raw_images):
                    # Not every image made it to the CDN: keep the URLs the
                    # tool returned so the result is not lost entirely.
                    source_urls = [
                        img for img in raw_images
                        if isinstance(img, str) and img.startswith(("http://", "https://"))
                    ]
                    if source_urls:
                        result_display["original_urls"] = source_urls
                if saved_paths:
                    result_display["saved_files"] = saved_paths
                result_display["image_count"] = len(images)
                result_str = json.dumps(result_display, ensure_ascii=False, indent=2)

        return ToolResult(
            title=f"ToolHub [{tool_id}] 执行成功",
            output=result_str,
            long_term_memory=f"Called ToolHub tool '{tool_id}' → success",
            images=images,
        )

    except httpx.TimeoutException:
        return ToolResult(
            title=f"ToolHub [{tool_id}] 调用超时",
            output="",
            error=f"调用工具 {tool_id} 超时({CALL_TIMEOUT:.0f}s),图像生成类工具可能需要更长时间。",
        )
    except Exception as e:
        return ToolResult(
            title=f"ToolHub [{tool_id}] 调用失败",
            output="",
            # Some exceptions carry no message; never report an empty error.
            error=str(e) or type(e).__name__,
        )
@tool(
    display={
        "zh": {"name": "上传本地图片", "params": {"local_path": "本地文件路径"}},
        "en": {"name": "Upload Local Image", "params": {"local_path": "Local file path"}},
    }
)
async def image_uploader(local_path: str) -> ToolResult:
    """Upload a local image to OSS and return a usable CDN URL (image_url)

    Use this tool when you need an HTTP link for a local image. Given a local
    file path, it uploads the file to OSS and returns the permanent CDN URL.

    Note: when calling toolhub_call, parameters such as image/image_url may be
    given as local paths directly and are uploaded automatically. This tool is
    for cases where the image URL is needed on its own.

    Args:
        local_path: Local image file path (relative or absolute)

    Returns:
        ToolResult containing the CDN URL of the uploaded file
    """
    import os
    from pathlib import Path

    failure_title = "图片上传失败"
    source = Path(local_path)

    # Validate the path first; report the first problem found.
    problem = None
    if not source.exists():
        problem = f"文件不存在: {local_path}"
    elif not source.is_file():
        problem = f"路径不是文件: {local_path}"
    if problem is not None:
        return ToolResult(title=failure_title, output="", error=problem)

    cdn_url = await _upload_to_oss(str(source.resolve()))
    if not cdn_url:
        return ToolResult(
            title=failure_title,
            output="",
            error=f"OSS 上传失败,请检查文件路径和网络连接: {local_path}",
        )

    details = {
        "local_path": str(source.resolve()),
        "cdn_url": cdn_url,
        "file_size": os.path.getsize(source),
    }
    return ToolResult(
        title="图片上传成功",
        output=json.dumps(details, ensure_ascii=False, indent=2),
        long_term_memory=f"Uploaded {local_path} → {cdn_url}",
    )
@tool(
    display={
        "zh": {"name": "下载图片到本地", "params": {"url": "图片URL", "save_path": "保存路径"}},
        "en": {"name": "Download Image", "params": {"url": "Image URL", "save_path": "Save path"}},
    }
)
async def image_downloader(url: str, save_path: str = "") -> ToolResult:
    """Download an image from the network to a local file

    Downloads an image from an HTTP/HTTPS link and saves it locally. Suitable
    for saving CDN images, generation results, etc. into a local directory.

    Args:
        url: HTTP/HTTPS link of the image
        save_path: Local path to save to (relative or absolute). When omitted,
            the file is saved in the current output directory with a file name
            taken from the URL (or ``download_<unix seconds>.png`` when the
            URL has no recognised image file name).

    Returns:
        ToolResult containing the local file path and file size of the download
    """
    from pathlib import Path
    from urllib.parse import urlparse, unquote

    if not url.startswith(("http://", "https://")):
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"无效的 URL(必须以 http:// 或 https:// 开头): {url}",
        )

    image_suffixes = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp")

    # Path preparation lives inside the try as well, so that filesystem errors
    # (e.g. an unwritable directory) are reported as a ToolResult like every
    # other failure instead of escaping as a raw exception.
    try:
        # Generate the save path automatically when none was given.
        if not save_path:
            out_dir = _get_output_dir("download")
            # Take the file name from the URL path.
            url_path = urlparse(url).path
            filename = Path(unquote(url_path)).name if url_path else ""
            if not filename.lower().endswith(image_suffixes):
                filename = f"download_{int(time.time())}.png"
            save_path = str(out_dir / filename)

        # Make sure the target directory exists.
        p = Path(save_path)
        p.parent.mkdir(parents=True, exist_ok=True)

        async with httpx.AsyncClient(timeout=60.0, follow_redirects=True, trust_env=False) as client:
            resp = await client.get(url)
            resp.raise_for_status()

        p.write_bytes(resp.content)
        file_size = p.stat().st_size

        result = {
            "save_path": str(p.resolve()),
            "file_size": file_size,
            "source_url": url,
        }
        return ToolResult(
            title="图片下载成功",
            output=json.dumps(result, ensure_ascii=False, indent=2),
            long_term_memory=f"Downloaded {url} → {save_path}",
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"HTTP 错误 {e.response.status_code}: {url}",
        )
    except Exception as e:
        return ToolResult(
            title="图片下载失败",
            output="",
            error=f"下载失败: {e}",
        )
|