| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- """
- Tool Registry - 工具注册表和装饰器
- 职责:
- 1. @tool 装饰器:自动注册工具并生成 Schema
- 2. 管理所有工具的 Schema 和实现
- 3. 路由工具调用到具体实现
- 4. 支持域名过滤、敏感数据处理、工具统计
- 从 Resonote/llm/tools/registry.py 抽取并扩展
- """
import inspect
import json
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Union

from agent.tools.url_matcher import filter_by_url
# Module-scoped logger; its name follows this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class ToolStats:
    """Usage statistics for a single registered tool.

    ``ToolRegistry.execute`` bumps ``call_count`` when a call *starts*, but only
    updates ``success_count`` / ``failure_count`` / ``total_duration`` when the call
    *finishes*. The derived metrics are therefore computed over finished calls, so
    in-flight (or cancelled) calls do not drag the averages down.
    """

    def __init__(self):
        self.call_count: int = 0                  # calls started, including ones still running
        self.success_count: int = 0
        self.failure_count: int = 0
        self.total_duration: float = 0.0         # seconds, summed over finished calls
        self.last_called: Optional[float] = None  # Unix timestamp of the most recent call start

    @property
    def finished_count(self) -> int:
        """Number of calls that completed, successfully or not."""
        return self.success_count + self.failure_count

    @property
    def average_duration(self) -> float:
        """Mean duration in seconds of finished calls (0.0 when none have finished)."""
        finished = self.finished_count
        return self.total_duration / finished if finished else 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of finished calls that succeeded (0.0 when none have finished)."""
        finished = self.finished_count
        return self.success_count / finished if finished else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the counters and derived metrics."""
        return {
            "call_count": self.call_count,
            "success_count": self.success_count,
            "failure_count": self.failure_count,
            "average_duration": self.average_duration,
            "success_rate": self.success_rate,
            "last_called": self.last_called
        }
class ToolRegistry:
    """Registry of tool implementations, their schemas, UI metadata and usage stats.

    Each entry of ``_tools`` (keyed by ``func.__name__``) has the shape::

        {
            "func": <callable, sync or async>,
            "schema": <tool schema dict>,
            "url_patterns": <list of URL patterns, or None for no restriction>,
            "ui_metadata": {
                "requires_confirmation": bool,
                "editable_params": [param names],
                "display": {locale: {"name": ..., "params": {...}}},
            },
        }
    """

    def __init__(self):
        self._tools: Dict[str, Dict[str, Any]] = {}
        # Per-tool usage counters; same keys as ``_tools``
        self._stats: Dict[str, ToolStats] = {}

    # ------------------------------------------------------------------ registration

    def register(
        self,
        func: Callable,
        schema: Optional[Dict] = None,
        requires_confirmation: bool = False,
        editable_params: Optional[List[str]] = None,
        display: Optional[Dict[str, Dict[str, Any]]] = None,
        url_patterns: Optional[List[str]] = None
    ):
        """
        Register a tool under ``func.__name__``.

        Args:
            func: Tool implementation (sync or async).
            schema: Tool schema; generated with ``SchemaGenerator.generate(func)`` when None.
            requires_confirmation: Whether the user must confirm before the tool runs.
            editable_params: Parameter names the user is allowed to edit.
            display: i18n display info, e.g. {"zh": {"name": "xx", "params": {...}}, "en": {...}}.
            url_patterns: URL patterns such as ["*.google.com"]; None means no restriction.

        Raises:
            Whatever the schema generator raises when ``schema`` is None and generation fails.

        Registering an already-registered name replaces the previous tool and resets its stats.
        """
        func_name = func.__name__

        if schema is None:
            try:
                from agent.tools.schema import SchemaGenerator
                schema = SchemaGenerator.generate(func)
            except Exception as e:
                logger.error(f"Failed to generate schema for {func_name}: {e}")
                raise

        if func_name in self._tools:
            # Two tools with the same function name would otherwise collide silently
            logger.warning(f"[ToolRegistry] Tool '{func_name}' is already registered; replacing it")

        # Copy so later mutation of the caller's list cannot change the registry
        editable = list(editable_params or [])
        self._tools[func_name] = {
            "func": func,
            "schema": schema,
            "url_patterns": url_patterns,
            "ui_metadata": {
                "requires_confirmation": requires_confirmation,
                "editable_params": editable,
                "display": display or {}
            }
        }
        self._stats[func_name] = ToolStats()

        logger.debug(
            f"[ToolRegistry] Registered: {func_name} "
            f"(requires_confirmation={requires_confirmation}, "
            f"editable_params={editable}, "
            f"url_patterns={url_patterns or 'none'})"
        )

    def is_registered(self, tool_name: str) -> bool:
        """Return True if a tool with this name is registered."""
        return tool_name in self._tools

    # ------------------------------------------------------------------ schemas / names

    def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
        """
        Return tool schemas.

        Args:
            tool_names: Names to look up (None = all registered tools). Unknown names are
                skipped with a warning.

        Returns:
            List of tool schemas (OpenAI tool schema format), in the order requested.
        """
        if tool_names is None:
            tool_names = list(self._tools)

        schemas = []
        for name in tool_names:
            tool = self._tools.get(name)
            if tool is None:
                logger.warning(f"[ToolRegistry] Tool not found: {name}")
                continue
            schemas.append(tool["schema"])
        return schemas

    def get_tool_names(self, current_url: Optional[str] = None) -> List[str]:
        """
        Return registered tool names, optionally filtered by URL.

        Args:
            current_url: Current page URL (None = return every registered tool).

        Returns:
            Tool names whose ``url_patterns`` match ``current_url`` according to
            ``filter_by_url``.
        """
        if current_url is None:
            return list(self._tools)

        tool_items = [
            {"name": name, "url_patterns": tool["url_patterns"]}
            for name, tool in self._tools.items()
        ]
        filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
        return [item["name"] for item in filtered]

    def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
        """
        Return the schemas of tools available on the given URL.

        Args:
            current_url: Current page URL. None returns the schemas of *all* registered
                tools (same as ``get_tool_names(None)``), not only unrestricted ones.

        Returns:
            Filtered list of tool schemas.
        """
        return self.get_schemas(self.get_tool_names(current_url))

    # ------------------------------------------------------------------ execution

    async def execute(
        self,
        name: str,
        arguments: Dict[str, Any],
        uid: str = "",
        context: Optional[Dict[str, Any]] = None,
        sensitive_data: Optional[Dict[str, Any]] = None
    ) -> Union[str, Dict[str, Any]]:
        """
        Execute a tool call and record its statistics.

        Args:
            name: Tool name.
            arguments: Tool arguments (as produced by the model); None is treated as {}.
            uid: User id, injected if the tool accepts a ``uid`` parameter.
            context: Extra context, injected if the tool accepts a ``context`` parameter;
                its ``"page_url"`` is used for sensitive-data replacement.
            sensitive_data: Secrets used to replace ``<secret>`` placeholders in arguments.

        Returns:
            A string for plain results (str results as-is, other values JSON-encoded), or a
            dict with "text" plus "images" and/or "tool_usage" when the tool returns a
            ToolResult carrying those extras. Failures (unknown tool, exception in the tool,
            unserializable result) are returned as a JSON string ``{"error": ...}``.
        """
        if name not in self._tools:
            error_msg = f"Unknown tool: {name}"
            logger.error(f"[ToolRegistry] {error_msg}")
            return json.dumps({"error": error_msg}, ensure_ascii=False)

        func = self._tools[name]["func"]
        stats = self._stats[name]
        stats.call_count += 1
        stats.last_called = time.time()  # wall-clock timestamp for reporting
        started = time.perf_counter()   # monotonic clock for measuring the duration

        try:
            kwargs = self._build_call_kwargs(func, arguments, uid, context, sensitive_data)
            result = func(**kwargs)
            # Await coroutine functions, and also sync callables that hand back an
            # awaitable (e.g. a plain wrapper around an async def).
            if inspect.isawaitable(result):
                result = await result
            # Serialize inside the try: an unserializable result is one failure, not
            # a success and a failure recorded for the same call.
            response = self._serialize_result(result)
        except Exception as e:
            self._record_outcome(stats, started, succeeded=False)
            error_msg = f"Error executing tool '{name}': {str(e)}"
            # logger.exception logs the message together with the traceback
            logger.exception(f"[ToolRegistry] {error_msg}")
            return json.dumps({"error": error_msg}, ensure_ascii=False)

        self._record_outcome(stats, started, succeeded=True)
        return response

    @staticmethod
    def _build_call_kwargs(
        func: Callable,
        arguments: Optional[Dict[str, Any]],
        uid: str,
        context: Optional[Dict[str, Any]],
        sensitive_data: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build ``func``'s keyword arguments from model arguments plus injected values."""
        arguments = arguments or {}
        if sensitive_data:
            from agent.tools.sensitive import replace_sensitive_data
            current_url = context.get("page_url") if context else None
            arguments = replace_sensitive_data(arguments, sensitive_data, current_url)

        kwargs = dict(arguments)
        params = inspect.signature(func).parameters
        # Injected values overwrite anything the model supplied under the same name
        if "uid" in params:
            kwargs["uid"] = uid
        if "context" in params:
            kwargs["context"] = context
        return kwargs

    @staticmethod
    def _serialize_result(result: Any) -> Union[str, Dict[str, Any]]:
        """Convert a tool's return value into the form returned by ``execute``."""
        if isinstance(result, str):
            return result

        from agent.tools.models import ToolResult
        if isinstance(result, ToolResult):
            payload: Dict[str, Any] = {"text": result.to_llm_message()}
            if result.images:
                payload["images"] = result.images
            if result.tool_usage:
                payload["tool_usage"] = result.tool_usage
            # Backward compatibility: text-only results stay plain strings
            return payload if len(payload) > 1 else payload["text"]

        return json.dumps(result, ensure_ascii=False, indent=2)

    @staticmethod
    def _record_outcome(stats, started: float, succeeded: bool) -> None:
        """Record one finished call (its outcome and elapsed time) on a ToolStats object."""
        if succeeded:
            stats.success_count += 1
        else:
            stats.failure_count += 1
        stats.total_duration += time.perf_counter() - started

    # ------------------------------------------------------------------ statistics

    def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """
        Return usage statistics.

        Args:
            tool_name: A single tool name (None/empty = all tools).

        Returns:
            {tool_name: stats dict}; empty if ``tool_name`` is unknown.
        """
        if tool_name:
            stats = self._stats.get(tool_name)
            return {tool_name: stats.to_dict()} if stats is not None else {}
        return {name: stats.to_dict() for name, stats in self._stats.items()}

    def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
        """
        Return the names of the top-ranked tools.

        Args:
            limit: Maximum number of names to return.
            by: Ranking criterion: "call_count" (most first), "success_rate" (highest
                first) or "average_duration" (fastest first; tools with no finished
                calls go last because their 0.0 average is not a real measurement).

        Raises:
            ValueError: If ``by`` is not one of the supported criteria.
        """
        orderings = {
            "call_count": (lambda s: s.call_count, True),
            "success_rate": (lambda s: s.success_rate, True),
            "average_duration": (
                lambda s: (s.success_count + s.failure_count == 0, s.average_duration),
                False,
            ),
        }
        if by not in orderings:
            raise ValueError(f"Invalid sort by: {by}")
        key, reverse = orderings[by]
        ranked = sorted(self._stats.items(), key=lambda item: key(item[1]), reverse=reverse)
        return [name for name, _ in ranked[:limit]]

    # ------------------------------------------------------------------ UI helpers

    def _ui_metadata_for_call(self, tool_call: Dict) -> Dict[str, Any]:
        """Return ui_metadata for the tool named in a tool call dict, or {} if unknown."""
        tool_name = (tool_call.get("function") or {}).get("name")
        tool = self._tools.get(tool_name) if tool_name else None
        return tool["ui_metadata"] if tool else {}

    def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
        """Return True if any of the tool calls targets a tool that requires confirmation."""
        return any(self.get_confirmation_flags(tool_calls))

    def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
        """Return, per tool call, whether its tool requires confirmation (False if unknown)."""
        return [
            self._ui_metadata_for_call(tc).get("requires_confirmation", False)
            for tc in tool_calls
        ]

    def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
        """Return True if any of the tool calls targets a tool with editable parameters."""
        return any(
            self._ui_metadata_for_call(tc).get("editable_params")
            for tc in tool_calls
        )

    def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
        """Map each tool call's "id" to a copy of its tool's editable parameter names."""
        return {
            tc.get("id"): list(self._ui_metadata_for_call(tc).get("editable_params", []))
            for tc in tool_calls
        }

    def get_ui_metadata(
        self,
        locale: str = "zh",
        tool_names: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, Any]]:
        """
        Return per-tool UI metadata for the frontend.

        Args:
            locale: Key into each tool's ``display`` mapping; missing locales fall back to
                the raw tool name and no parameter display names.
            tool_names: Tools to include (None = all); unknown names are skipped.

        Returns:
            {
                "tool_name": {
                    "display_name": "Search notes",
                    "param_display_names": {"query": "Search keywords"},
                    "requires_confirmation": false,
                    "editable_params": ["query"]
                }
            }
            Lists and dicts are copies, so callers may mutate them freely.
        """
        names = list(self._tools) if tool_names is None else tool_names
        metadata = {}
        for name in names:
            tool = self._tools.get(name)
            if tool is None:
                continue
            ui_meta = tool["ui_metadata"]
            display = ui_meta.get("display", {}).get(locale, {})
            metadata[name] = {
                "display_name": display.get("name", name),
                "param_display_names": dict(display.get("params", {})),
                "requires_confirmation": ui_meta.get("requires_confirmation", False),
                "editable_params": list(ui_meta.get("editable_params", []))
            }
        return metadata
# Module-level singleton registry; @tool registers into it and get_tool_registry() returns it.
_global_registry: ToolRegistry = ToolRegistry()
def _apply_schema_overrides(
    schema: Dict[str, Any],
    description: Optional[str],
    param_descriptions: Optional[Dict[str, str]],
) -> Dict[str, Any]:
    """Write explicit description overrides into a freshly generated tool schema (in place)."""
    # NOTE(review): assumes SchemaGenerator emits the OpenAI tool shape
    # {"type": "function", "function": {"description": ..., "parameters": {"properties": ...}}}
    # or that inner dict directly -- confirm against agent.tools.schema.
    fn_schema = schema.get("function", schema)
    if description:
        fn_schema["description"] = description
    properties = fn_schema.get("parameters", {}).get("properties", {})
    for param, text in (param_descriptions or {}).items():
        if param in properties:
            properties[param]["description"] = text
        else:
            logger.warning(f"[tool] Description given for unknown parameter '{param}'")
    return schema


def tool(
    description: Optional[str] = None,
    param_descriptions: Optional[Dict[str, str]] = None,
    requires_confirmation: bool = False,
    editable_params: Optional[List[str]] = None,
    display: Optional[Dict[str, Dict[str, Any]]] = None,
    url_patterns: Optional[List[str]] = None
):
    """
    Decorator that registers a function in the global tool registry.

    The schema is generated from the function by ``SchemaGenerator``. When
    ``description`` or ``param_descriptions`` are given they override the generated
    descriptions. The decorated function itself is returned unchanged.

    Usable both bare (``@tool``) and with arguments (``@tool(...)``).

    Args:
        description: Tool description override (otherwise derived from the docstring).
        param_descriptions: Per-parameter description overrides (otherwise derived from
            the docstring).
        requires_confirmation: Whether the user must confirm the call (default False).
        editable_params: Parameter names the user may edit.
        display: i18n display info.
        url_patterns: URL patterns such as ["*.google.com"]; None means no restriction.

    Example:
        @tool(
            editable_params=["query"],
            url_patterns=["*.google.com"],
            display={
                "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
                "en": {"name": "Search Notes", "params": {"query": "Query"}}
            }
        )
        async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
            '''Search the user's note blocks'''
            ...
    """
    # Bare ``@tool`` usage: Python passes the decorated function as the first argument.
    if callable(description):
        return tool()(description)

    def decorator(func: Callable) -> Callable:
        schema = None
        if description or param_descriptions:
            # Generate here (instead of inside register) so the overrides can be applied
            try:
                from agent.tools.schema import SchemaGenerator
                schema = _apply_schema_overrides(
                    SchemaGenerator.generate(func), description, param_descriptions
                )
            except Exception as e:
                logger.error(f"Failed to generate schema for {func.__name__}: {e}")
                raise

        _global_registry.register(
            func,
            schema=schema,
            requires_confirmation=requires_confirmation,
            editable_params=editable_params,
            display=display,
            url_patterns=url_patterns
        )
        return func

    return decorator
def get_tool_registry() -> ToolRegistry:
    """Return the module-level singleton registry that ``@tool`` registers into."""
    registry = _global_registry
    return registry
|