|
|
@@ -5,234 +5,381 @@ Tool Registry - 工具注册表和装饰器
|
|
|
1. @tool 装饰器:自动注册工具并生成 Schema
|
|
|
2. 管理所有工具的 Schema 和实现
|
|
|
3. 路由工具调用到具体实现
|
|
|
+4. 支持域名过滤、敏感数据处理、工具统计
|
|
|
|
|
|
-从 Resonote/llm/tools/registry.py 抽取
|
|
|
+从 Resonote/llm/tools/registry.py 抽取并扩展
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
import inspect
|
|
|
import logging
|
|
|
+import time
|
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
|
|
+from reson_agent.tools.url_matcher import filter_by_url
|
|
|
+
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
+class ToolStats:
|
|
|
+ """工具使用统计"""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self.call_count: int = 0
|
|
|
+ self.success_count: int = 0
|
|
|
+ self.failure_count: int = 0
|
|
|
+ self.total_duration: float = 0.0
|
|
|
+ self.last_called: Optional[float] = None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def average_duration(self) -> float:
|
|
|
+ """平均执行时间(秒)"""
|
|
|
+ return self.total_duration / self.call_count if self.call_count > 0 else 0.0
|
|
|
+
|
|
|
+ @property
|
|
|
+ def success_rate(self) -> float:
|
|
|
+ """成功率"""
|
|
|
+ return self.success_count / self.call_count if self.call_count > 0 else 0.0
|
|
|
+
|
|
|
+ def to_dict(self) -> Dict[str, Any]:
|
|
|
+ return {
|
|
|
+ "call_count": self.call_count,
|
|
|
+ "success_count": self.success_count,
|
|
|
+ "failure_count": self.failure_count,
|
|
|
+ "average_duration": self.average_duration,
|
|
|
+ "success_rate": self.success_rate,
|
|
|
+ "last_called": self.last_called
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
class ToolRegistry:
|
|
|
- """工具注册表"""
|
|
|
-
|
|
|
- def __init__(self):
|
|
|
- self._tools: Dict[str, Dict[str, Any]] = {}
|
|
|
-
|
|
|
- def register(
|
|
|
- self,
|
|
|
- func: Callable,
|
|
|
- schema: Optional[Dict] = None,
|
|
|
- requires_confirmation: bool = False,
|
|
|
- editable_params: Optional[List[str]] = None,
|
|
|
- display: Optional[Dict[str, Dict[str, Any]]] = None
|
|
|
- ):
|
|
|
- """
|
|
|
- 注册工具
|
|
|
-
|
|
|
- Args:
|
|
|
- func: 工具函数
|
|
|
- schema: 工具 Schema(如果为 None,自动生成)
|
|
|
- requires_confirmation: 是否需要用户确认
|
|
|
- editable_params: 允许用户编辑的参数列表
|
|
|
- display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
|
|
|
- """
|
|
|
- func_name = func.__name__
|
|
|
-
|
|
|
- # 如果没有提供 Schema,自动生成
|
|
|
- if schema is None:
|
|
|
- try:
|
|
|
- from reson_agent.tools.schema import SchemaGenerator
|
|
|
- schema = SchemaGenerator.generate(func)
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Failed to generate schema for {func_name}: {e}")
|
|
|
- raise
|
|
|
-
|
|
|
- self._tools[func_name] = {
|
|
|
- "func": func,
|
|
|
- "schema": schema,
|
|
|
- "ui_metadata": {
|
|
|
- "requires_confirmation": requires_confirmation,
|
|
|
- "editable_params": editable_params or [],
|
|
|
- "display": display or {}
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- logger.debug(
|
|
|
- f"[ToolRegistry] Registered: {func_name} "
|
|
|
- f"(requires_confirmation={requires_confirmation}, "
|
|
|
- f"editable_params={editable_params or []})"
|
|
|
- )
|
|
|
-
|
|
|
- def is_registered(self, tool_name: str) -> bool:
|
|
|
- """检查工具是否已注册"""
|
|
|
- return tool_name in self._tools
|
|
|
-
|
|
|
- def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
|
|
|
- """
|
|
|
- 获取工具 Schema
|
|
|
-
|
|
|
- Args:
|
|
|
- tool_names: 工具名称列表(None = 所有工具)
|
|
|
-
|
|
|
- Returns:
|
|
|
- OpenAI Tool Schema 列表
|
|
|
- """
|
|
|
- if tool_names is None:
|
|
|
- tool_names = list(self._tools.keys())
|
|
|
-
|
|
|
- schemas = []
|
|
|
- for name in tool_names:
|
|
|
- if name in self._tools:
|
|
|
- schemas.append(self._tools[name]["schema"])
|
|
|
- else:
|
|
|
- logger.warning(f"[ToolRegistry] Tool not found: {name}")
|
|
|
-
|
|
|
- return schemas
|
|
|
-
|
|
|
- def get_tool_names(self) -> List[str]:
|
|
|
- """获取所有注册的工具名称"""
|
|
|
- return list(self._tools.keys())
|
|
|
-
|
|
|
- async def execute(
|
|
|
- self,
|
|
|
- name: str,
|
|
|
- arguments: Dict[str, Any],
|
|
|
- uid: str = "",
|
|
|
- context: Optional[Dict[str, Any]] = None
|
|
|
- ) -> str:
|
|
|
- """
|
|
|
- 执行工具调用
|
|
|
-
|
|
|
- Args:
|
|
|
- name: 工具名称
|
|
|
- arguments: 工具参数
|
|
|
- uid: 用户ID(自动注入)
|
|
|
- context: 额外上下文
|
|
|
-
|
|
|
- Returns:
|
|
|
- JSON 字符串格式的结果
|
|
|
- """
|
|
|
- if name not in self._tools:
|
|
|
- error_msg = f"Unknown tool: {name}"
|
|
|
- logger.error(f"[ToolRegistry] {error_msg}")
|
|
|
- return json.dumps({"error": error_msg}, ensure_ascii=False)
|
|
|
-
|
|
|
- try:
|
|
|
- func = self._tools[name]["func"]
|
|
|
-
|
|
|
- # 注入 uid
|
|
|
- kwargs = {**arguments, "uid": uid}
|
|
|
-
|
|
|
- # 注入 context(如果函数接受)
|
|
|
- sig = inspect.signature(func)
|
|
|
- if "context" in sig.parameters:
|
|
|
- kwargs["context"] = context
|
|
|
-
|
|
|
- # 执行函数
|
|
|
- if inspect.iscoroutinefunction(func):
|
|
|
- result = await func(**kwargs)
|
|
|
- else:
|
|
|
- result = func(**kwargs)
|
|
|
-
|
|
|
- # 返回 JSON 字符串
|
|
|
- if isinstance(result, str):
|
|
|
- return result
|
|
|
- return json.dumps(result, ensure_ascii=False, indent=2)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- error_msg = f"Error executing tool '{name}': {str(e)}"
|
|
|
- logger.error(f"[ToolRegistry] {error_msg}")
|
|
|
- import traceback
|
|
|
- logger.error(traceback.format_exc())
|
|
|
- return json.dumps({"error": error_msg}, ensure_ascii=False)
|
|
|
-
|
|
|
- def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
|
|
|
- """检查是否有工具需要用户确认"""
|
|
|
- for tc in tool_calls:
|
|
|
- tool_name = tc.get("function", {}).get("name")
|
|
|
- if tool_name and tool_name in self._tools:
|
|
|
- if self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False):
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
|
|
|
- """返回每个工具是否需要确认"""
|
|
|
- flags = []
|
|
|
- for tc in tool_calls:
|
|
|
- tool_name = tc.get("function", {}).get("name")
|
|
|
- if tool_name and tool_name in self._tools:
|
|
|
- flags.append(self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False))
|
|
|
- else:
|
|
|
- flags.append(False)
|
|
|
- return flags
|
|
|
-
|
|
|
- def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
|
|
|
- """检查是否有任何工具允许参数编辑"""
|
|
|
- for tc in tool_calls:
|
|
|
- tool_name = tc.get("function", {}).get("name")
|
|
|
- if tool_name and tool_name in self._tools:
|
|
|
- editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
|
|
|
- if editable_params:
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
|
|
|
- """返回每个工具调用的可编辑参数列表"""
|
|
|
- params_map = {}
|
|
|
- for tc in tool_calls:
|
|
|
- tool_call_id = tc.get("id")
|
|
|
- tool_name = tc.get("function", {}).get("name")
|
|
|
-
|
|
|
- if tool_name and tool_name in self._tools:
|
|
|
- editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
|
|
|
- params_map[tool_call_id] = editable_params
|
|
|
- else:
|
|
|
- params_map[tool_call_id] = []
|
|
|
-
|
|
|
- return params_map
|
|
|
-
|
|
|
- def get_ui_metadata(
|
|
|
- self,
|
|
|
- locale: str = "zh",
|
|
|
- tool_names: Optional[List[str]] = None
|
|
|
- ) -> Dict[str, Dict[str, Any]]:
|
|
|
- """
|
|
|
- 获取工具的UI元数据(用于前端展示)
|
|
|
-
|
|
|
- Returns:
|
|
|
- {
|
|
|
- "tool_name": {
|
|
|
- "display_name": "搜索笔记",
|
|
|
- "param_display_names": {"query": "搜索关键词"},
|
|
|
- "requires_confirmation": false,
|
|
|
- "editable_params": ["query"]
|
|
|
- }
|
|
|
- }
|
|
|
- """
|
|
|
- if tool_names is None:
|
|
|
- tool_names = list(self._tools.keys())
|
|
|
-
|
|
|
- metadata = {}
|
|
|
- for name in tool_names:
|
|
|
- if name not in self._tools:
|
|
|
- continue
|
|
|
-
|
|
|
- ui_meta = self._tools[name]["ui_metadata"]
|
|
|
- display = ui_meta.get("display", {}).get(locale, {})
|
|
|
-
|
|
|
- metadata[name] = {
|
|
|
- "display_name": display.get("name", name),
|
|
|
- "param_display_names": display.get("params", {}),
|
|
|
- "requires_confirmation": ui_meta.get("requires_confirmation", False),
|
|
|
- "editable_params": ui_meta.get("editable_params", [])
|
|
|
- }
|
|
|
-
|
|
|
- return metadata
|
|
|
+ """工具注册表"""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self._tools: Dict[str, Dict[str, Any]] = {}
|
|
|
+ self._stats: Dict[str, ToolStats] = {}
|
|
|
+
|
|
|
+ def register(
|
|
|
+ self,
|
|
|
+ func: Callable,
|
|
|
+ schema: Optional[Dict] = None,
|
|
|
+ requires_confirmation: bool = False,
|
|
|
+ editable_params: Optional[List[str]] = None,
|
|
|
+ display: Optional[Dict[str, Dict[str, Any]]] = None,
|
|
|
+ url_patterns: Optional[List[str]] = None
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 注册工具
|
|
|
+
|
|
|
+ Args:
|
|
|
+ func: 工具函数
|
|
|
+ schema: 工具 Schema(如果为 None,自动生成)
|
|
|
+ requires_confirmation: 是否需要用户确认
|
|
|
+ editable_params: 允许用户编辑的参数列表
|
|
|
+ display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
|
|
|
+ url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
|
|
|
+ """
|
|
|
+ func_name = func.__name__
|
|
|
+
|
|
|
+ # 如果没有提供 Schema,自动生成
|
|
|
+ if schema is None:
|
|
|
+ try:
|
|
|
+ from reson_agent.tools.schema import SchemaGenerator
|
|
|
+ schema = SchemaGenerator.generate(func)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Failed to generate schema for {func_name}: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ self._tools[func_name] = {
|
|
|
+ "func": func,
|
|
|
+ "schema": schema,
|
|
|
+ "url_patterns": url_patterns,
|
|
|
+ "ui_metadata": {
|
|
|
+ "requires_confirmation": requires_confirmation,
|
|
|
+ "editable_params": editable_params or [],
|
|
|
+ "display": display or {}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ # 初始化统计
|
|
|
+ self._stats[func_name] = ToolStats()
|
|
|
+
|
|
|
+ logger.debug(
|
|
|
+ f"[ToolRegistry] Registered: {func_name} "
|
|
|
+ f"(requires_confirmation={requires_confirmation}, "
|
|
|
+ f"editable_params={editable_params or []}, "
|
|
|
+ f"url_patterns={url_patterns or 'none'})"
|
|
|
+ )
|
|
|
+
|
|
|
+ def is_registered(self, tool_name: str) -> bool:
|
|
|
+ """检查工具是否已注册"""
|
|
|
+ return tool_name in self._tools
|
|
|
+
|
|
|
+ def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 获取工具 Schema
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tool_names: 工具名称列表(None = 所有工具)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ OpenAI Tool Schema 列表
|
|
|
+ """
|
|
|
+ if tool_names is None:
|
|
|
+ tool_names = list(self._tools.keys())
|
|
|
+
|
|
|
+ schemas = []
|
|
|
+ for name in tool_names:
|
|
|
+ if name in self._tools:
|
|
|
+ schemas.append(self._tools[name]["schema"])
|
|
|
+ else:
|
|
|
+ logger.warning(f"[ToolRegistry] Tool not found: {name}")
|
|
|
+
|
|
|
+ return schemas
|
|
|
+
|
|
|
+ def get_tool_names(self, current_url: Optional[str] = None) -> List[str]:
|
|
|
+ """
|
|
|
+ 获取工具名称列表(可选 URL 过滤)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ current_url: 当前 URL(None = 返回所有工具)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 工具名称列表
|
|
|
+ """
|
|
|
+ if current_url is None:
|
|
|
+ return list(self._tools.keys())
|
|
|
+
|
|
|
+ # 过滤工具
|
|
|
+ tool_items = [
|
|
|
+ {"name": name, "url_patterns": tool["url_patterns"]}
|
|
|
+ for name, tool in self._tools.items()
|
|
|
+ ]
|
|
|
+ filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
|
|
|
+ return [item["name"] for item in filtered]
|
|
|
+
|
|
|
+ def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 根据当前 URL 获取匹配的工具 Schema
|
|
|
+
|
|
|
+ Args:
|
|
|
+ current_url: 当前 URL(None = 返回无 URL 限制的工具)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 过滤后的工具 Schema 列表
|
|
|
+ """
|
|
|
+ tool_names = self.get_tool_names(current_url)
|
|
|
+ return self.get_schemas(tool_names)
|
|
|
+
|
|
|
+ async def execute(
|
|
|
+ self,
|
|
|
+ name: str,
|
|
|
+ arguments: Dict[str, Any],
|
|
|
+ uid: str = "",
|
|
|
+ context: Optional[Dict[str, Any]] = None,
|
|
|
+ sensitive_data: Optional[Dict[str, Any]] = None
|
|
|
+ ) -> str:
|
|
|
+ """
|
|
|
+ 执行工具调用
|
|
|
+
|
|
|
+ Args:
|
|
|
+ name: 工具名称
|
|
|
+ arguments: 工具参数
|
|
|
+ uid: 用户ID(自动注入)
|
|
|
+ context: 额外上下文
|
|
|
+ sensitive_data: 敏感数据字典(用于替换 <secret> 占位符)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ JSON 字符串格式的结果
|
|
|
+ """
|
|
|
+ if name not in self._tools:
|
|
|
+ error_msg = f"Unknown tool: {name}"
|
|
|
+ logger.error(f"[ToolRegistry] {error_msg}")
|
|
|
+ return json.dumps({"error": error_msg}, ensure_ascii=False)
|
|
|
+
|
|
|
+ start_time = time.time()
|
|
|
+ stats = self._stats[name]
|
|
|
+ stats.call_count += 1
|
|
|
+ stats.last_called = start_time
|
|
|
+
|
|
|
+ try:
|
|
|
+ func = self._tools[name]["func"]
|
|
|
+
|
|
|
+ # 处理敏感数据占位符
|
|
|
+ if sensitive_data:
|
|
|
+ from reson_agent.tools.sensitive import replace_sensitive_data
|
|
|
+ current_url = context.get("page_url") if context else None
|
|
|
+ arguments = replace_sensitive_data(arguments, sensitive_data, current_url)
|
|
|
+
|
|
|
+ # 注入 uid
|
|
|
+ kwargs = {**arguments, "uid": uid}
|
|
|
+
|
|
|
+ # 注入 context(如果函数接受)
|
|
|
+ sig = inspect.signature(func)
|
|
|
+ if "context" in sig.parameters:
|
|
|
+ kwargs["context"] = context
|
|
|
+
|
|
|
+ # 执行函数
|
|
|
+ if inspect.iscoroutinefunction(func):
|
|
|
+ result = await func(**kwargs)
|
|
|
+ else:
|
|
|
+ result = func(**kwargs)
|
|
|
+
|
|
|
+ # 记录成功
|
|
|
+ stats.success_count += 1
|
|
|
+ duration = time.time() - start_time
|
|
|
+ stats.total_duration += duration
|
|
|
+
|
|
|
+ # 返回 JSON 字符串
|
|
|
+ if isinstance(result, str):
|
|
|
+ return result
|
|
|
+ return json.dumps(result, ensure_ascii=False, indent=2)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ # 记录失败
|
|
|
+ stats.failure_count += 1
|
|
|
+ duration = time.time() - start_time
|
|
|
+ stats.total_duration += duration
|
|
|
+
|
|
|
+ error_msg = f"Error executing tool '{name}': {str(e)}"
|
|
|
+ logger.error(f"[ToolRegistry] {error_msg}")
|
|
|
+ import traceback
|
|
|
+ logger.error(traceback.format_exc())
|
|
|
+ return json.dumps({"error": error_msg}, ensure_ascii=False)
|
|
|
+
|
|
|
+ def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 获取工具统计信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tool_name: 工具名称(None = 所有工具)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 统计信息字典
|
|
|
+ """
|
|
|
+ if tool_name:
|
|
|
+ if tool_name in self._stats:
|
|
|
+ return {tool_name: self._stats[tool_name].to_dict()}
|
|
|
+ return {}
|
|
|
+
|
|
|
+ return {name: stats.to_dict() for name, stats in self._stats.items()}
|
|
|
+
|
|
|
+ def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
|
|
|
+ """
|
|
|
+ 获取排名靠前的工具
|
|
|
+
|
|
|
+ Args:
|
|
|
+ limit: 返回数量
|
|
|
+ by: 排序依据(call_count, success_rate, average_duration)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 工具名称列表
|
|
|
+ """
|
|
|
+ if by == "call_count":
|
|
|
+ sorted_tools = sorted(
|
|
|
+ self._stats.items(),
|
|
|
+ key=lambda x: x[1].call_count,
|
|
|
+ reverse=True
|
|
|
+ )
|
|
|
+ elif by == "success_rate":
|
|
|
+ sorted_tools = sorted(
|
|
|
+ self._stats.items(),
|
|
|
+ key=lambda x: x[1].success_rate,
|
|
|
+ reverse=True
|
|
|
+ )
|
|
|
+ elif by == "average_duration":
|
|
|
+ sorted_tools = sorted(
|
|
|
+ self._stats.items(),
|
|
|
+ key=lambda x: x[1].average_duration,
|
|
|
+ reverse=False # 越快越好
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Invalid sort by: {by}")
|
|
|
+
|
|
|
+ return [name for name, _ in sorted_tools[:limit]]
|
|
|
+
|
|
|
+ def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
|
|
|
+ """检查是否有工具需要用户确认"""
|
|
|
+ for tc in tool_calls:
|
|
|
+ tool_name = tc.get("function", {}).get("name")
|
|
|
+ if tool_name and tool_name in self._tools:
|
|
|
+ if self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False):
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
|
|
|
+ """返回每个工具是否需要确认"""
|
|
|
+ flags = []
|
|
|
+ for tc in tool_calls:
|
|
|
+ tool_name = tc.get("function", {}).get("name")
|
|
|
+ if tool_name and tool_name in self._tools:
|
|
|
+ flags.append(self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False))
|
|
|
+ else:
|
|
|
+ flags.append(False)
|
|
|
+ return flags
|
|
|
+
|
|
|
+ def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
|
|
|
+ """检查是否有任何工具允许参数编辑"""
|
|
|
+ for tc in tool_calls:
|
|
|
+ tool_name = tc.get("function", {}).get("name")
|
|
|
+ if tool_name and tool_name in self._tools:
|
|
|
+ editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
|
|
|
+ if editable_params:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
|
|
|
+ """返回每个工具调用的可编辑参数列表"""
|
|
|
+ params_map = {}
|
|
|
+ for tc in tool_calls:
|
|
|
+ tool_call_id = tc.get("id")
|
|
|
+ tool_name = tc.get("function", {}).get("name")
|
|
|
+
|
|
|
+ if tool_name and tool_name in self._tools:
|
|
|
+ editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
|
|
|
+ params_map[tool_call_id] = editable_params
|
|
|
+ else:
|
|
|
+ params_map[tool_call_id] = []
|
|
|
+
|
|
|
+ return params_map
|
|
|
+
|
|
|
+ def get_ui_metadata(
|
|
|
+ self,
|
|
|
+ locale: str = "zh",
|
|
|
+ tool_names: Optional[List[str]] = None
|
|
|
+ ) -> Dict[str, Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 获取工具的UI元数据(用于前端展示)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ {
|
|
|
+ "tool_name": {
|
|
|
+ "display_name": "搜索笔记",
|
|
|
+ "param_display_names": {"query": "搜索关键词"},
|
|
|
+ "requires_confirmation": false,
|
|
|
+ "editable_params": ["query"]
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """
|
|
|
+ if tool_names is None:
|
|
|
+ tool_names = list(self._tools.keys())
|
|
|
+
|
|
|
+ metadata = {}
|
|
|
+ for name in tool_names:
|
|
|
+ if name not in self._tools:
|
|
|
+ continue
|
|
|
+
|
|
|
+ ui_meta = self._tools[name]["ui_metadata"]
|
|
|
+ display = ui_meta.get("display", {}).get(locale, {})
|
|
|
+
|
|
|
+ metadata[name] = {
|
|
|
+ "display_name": display.get("name", name),
|
|
|
+ "param_display_names": display.get("params", {}),
|
|
|
+ "requires_confirmation": ui_meta.get("requires_confirmation", False),
|
|
|
+ "editable_params": ui_meta.get("editable_params", [])
|
|
|
+ }
|
|
|
+
|
|
|
+ return metadata
|
|
|
|
|
|
|
|
|
# 全局单例
|
|
|
@@ -240,47 +387,51 @@ _global_registry = ToolRegistry()
|
|
|
|
|
|
|
|
|
def tool(
|
|
|
- description: Optional[str] = None,
|
|
|
- param_descriptions: Optional[Dict[str, str]] = None,
|
|
|
- requires_confirmation: bool = False,
|
|
|
- editable_params: Optional[List[str]] = None,
|
|
|
- display: Optional[Dict[str, Dict[str, Any]]] = None
|
|
|
+ description: Optional[str] = None,
|
|
|
+ param_descriptions: Optional[Dict[str, str]] = None,
|
|
|
+ requires_confirmation: bool = False,
|
|
|
+ editable_params: Optional[List[str]] = None,
|
|
|
+ display: Optional[Dict[str, Dict[str, Any]]] = None,
|
|
|
+ url_patterns: Optional[List[str]] = None
|
|
|
):
|
|
|
- """
|
|
|
- 工具装饰器 - 自动注册工具并生成 Schema
|
|
|
-
|
|
|
- Args:
|
|
|
- description: 函数描述(可选,从 docstring 提取)
|
|
|
- param_descriptions: 参数描述(可选,从 docstring 提取)
|
|
|
- requires_confirmation: 是否需要用户确认(默认 False)
|
|
|
- editable_params: 允许用户编辑的参数列表
|
|
|
- display: i18n 展示信息
|
|
|
-
|
|
|
- Example:
|
|
|
- @tool(
|
|
|
- editable_params=["query"],
|
|
|
- display={
|
|
|
- "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
|
|
|
- "en": {"name": "Search Notes", "params": {"query": "Query"}}
|
|
|
- }
|
|
|
- )
|
|
|
- async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
|
|
|
- '''搜索用户的笔记块'''
|
|
|
- ...
|
|
|
- """
|
|
|
- def decorator(func: Callable) -> Callable:
|
|
|
- # 注册到全局 registry
|
|
|
- _global_registry.register(
|
|
|
- func,
|
|
|
- requires_confirmation=requires_confirmation,
|
|
|
- editable_params=editable_params,
|
|
|
- display=display
|
|
|
- )
|
|
|
- return func
|
|
|
-
|
|
|
- return decorator
|
|
|
+ """
|
|
|
+ 工具装饰器 - 自动注册工具并生成 Schema
|
|
|
+
|
|
|
+ Args:
|
|
|
+ description: 函数描述(可选,从 docstring 提取)
|
|
|
+ param_descriptions: 参数描述(可选,从 docstring 提取)
|
|
|
+ requires_confirmation: 是否需要用户确认(默认 False)
|
|
|
+ editable_params: 允许用户编辑的参数列表
|
|
|
+ display: i18n 展示信息
|
|
|
+ url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
|
|
|
+
|
|
|
+ Example:
|
|
|
+ @tool(
|
|
|
+ editable_params=["query"],
|
|
|
+ url_patterns=["*.google.com"],
|
|
|
+ display={
|
|
|
+ "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
|
|
|
+ "en": {"name": "Search Notes", "params": {"query": "Query"}}
|
|
|
+ }
|
|
|
+ )
|
|
|
+ async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
|
|
|
+ '''搜索用户的笔记块'''
|
|
|
+ ...
|
|
|
+ """
|
|
|
+ def decorator(func: Callable) -> Callable:
|
|
|
+ # 注册到全局 registry
|
|
|
+ _global_registry.register(
|
|
|
+ func,
|
|
|
+ requires_confirmation=requires_confirmation,
|
|
|
+ editable_params=editable_params,
|
|
|
+ display=display,
|
|
|
+ url_patterns=url_patterns
|
|
|
+ )
|
|
|
+ return func
|
|
|
+
|
|
|
+ return decorator
|
|
|
|
|
|
|
|
|
def get_tool_registry() -> ToolRegistry:
|
|
|
- """获取全局工具注册表"""
|
|
|
- return _global_registry
|
|
|
+ """获取全局工具注册表"""
|
|
|
+ return _global_registry
|