"""
Tool Registry - tool registry and decorator.

Responsibilities:
1. ``@tool`` decorator: registers a tool and auto-generates its schema.
2. Holds the schema and implementation of every registered tool.
3. Routes tool calls to the concrete implementation.
4. Supports URL-based filtering, sensitive-data substitution and per-tool
   usage statistics.

Extracted from Resonote/llm/tools/registry.py and extended.
"""

import inspect
import json
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Union

from agent.tools.url_matcher import filter_by_url

logger = logging.getLogger(__name__)


class ToolStats:
    """Usage statistics for a single tool."""

    def __init__(self):
        # Total number of execute() calls for this tool (successes + failures).
        self.call_count: int = 0
        # Calls whose function ran AND whose result was serialized successfully.
        self.success_count: int = 0
        # Calls that raised anywhere inside execute()'s guarded section.
        self.failure_count: int = 0
        # Sum of wall-clock durations (seconds) over all finished calls.
        self.total_duration: float = 0.0
        # time.time() timestamp of the most recent call, None if never called.
        self.last_called: Optional[float] = None

    @property
    def average_duration(self) -> float:
        """Average execution time in seconds (0.0 if never called)."""
        return self.total_duration / self.call_count if self.call_count > 0 else 0.0

    @property
    def success_rate(self) -> float:
        """Fraction of calls that succeeded (0.0 if never called)."""
        return self.success_count / self.call_count if self.call_count > 0 else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of the statistics."""
        return {
            "call_count": self.call_count,
            "success_count": self.success_count,
            "failure_count": self.failure_count,
            "average_duration": self.average_duration,
            "success_rate": self.success_rate,
            "last_called": self.last_called,
        }


class ToolRegistry:
    """Registry mapping tool names to their implementation, schema and UI metadata."""

    def __init__(self):
        self._tools: Dict[str, Dict[str, Any]] = {}
        self._stats: Dict[str, ToolStats] = {}

    def register(
        self,
        func: Callable,
        schema: Optional[Dict] = None,
        requires_confirmation: bool = False,
        editable_params: Optional[List[str]] = None,
        display: Optional[Dict[str, Dict[str, Any]]] = None,
        url_patterns: Optional[List[str]] = None,
        hidden_params: Optional[List[str]] = None,
        inject_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Register a tool under ``func.__name__``.

        Registering the same name again replaces the previous entry and resets
        its statistics.

        Args:
            func: The tool function (sync or async).
            schema: Tool schema; auto-generated via SchemaGenerator when None.
            requires_confirmation: Whether the user must confirm before the call.
            editable_params: Parameters the user may edit before confirming.
            display: i18n display info,
                e.g. {"zh": {"name": "xx", "params": {...}}, "en": {...}}.
            url_patterns: URL patterns such as ["*.google.com"]; None = unrestricted.
            hidden_params: Parameters excluded from the schema (invisible to the
                LLM). "uid" and "context" listed here are filled in by execute().
            inject_params: {param_name: rule}. A rule is either
                {"mode": "default" | "merge", "key": "<context key path>"}
                or, in the legacy form, a plain value or a zero-argument
                callable used when the LLM did not supply the parameter.

        Raises:
            Exception: Whatever SchemaGenerator raises when auto-generation fails.
        """
        func_name = func.__name__

        # Auto-generate the schema when none was provided.
        if schema is None:
            try:
                from agent.tools.schema import SchemaGenerator
                schema = SchemaGenerator.generate(func, hidden_params=hidden_params or [])
            except Exception as e:
                logger.error(f"Failed to generate schema for {func_name}: {e}")
                raise

        self._tools[func_name] = {
            "func": func,
            "schema": schema,
            "url_patterns": url_patterns,
            "hidden_params": hidden_params or [],
            "inject_params": inject_params or {},
            "ui_metadata": {
                "requires_confirmation": requires_confirmation,
                "editable_params": editable_params or [],
                "display": display or {},
            },
        }

        # Fresh statistics for the (re-)registered tool.
        self._stats[func_name] = ToolStats()

        logger.debug(
            f"[ToolRegistry] Registered: {func_name} "
            f"(requires_confirmation={requires_confirmation}, "
            f"editable_params={editable_params or []}, "
            f"url_patterns={url_patterns or 'none'})"
        )

    @staticmethod
    def _resolve_key_path(context: Dict[str, Any], key_path: str) -> Any:
        """
        Look up a value in ``context`` by a dotted path.

        The first segment is a dict key of ``context``; every following
        segment is read with ``getattr``. For example
        "knowledge_config.default_tags" -> context["knowledge_config"].default_tags

        Args:
            context: Context dictionary.
            key_path: Dotted lookup path.

        Returns:
            The resolved value, or None if any step of the path is missing.
        """
        parts = key_path.split(".")
        value = context.get(parts[0])
        for part in parts[1:]:
            if value is None:
                return None
            value = getattr(value, part, None)
        return value

    def is_registered(self, tool_name: str) -> bool:
        """Return True if a tool with this name has been registered."""
        return tool_name in self._tools

    def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
        """
        Return the schemas of the given tools.

        Args:
            tool_names: Tool names to look up (None = all registered tools).
                Unknown names are skipped with a warning.

        Returns:
            List of tool schemas, in the order of ``tool_names``.
        """
        if tool_names is None:
            tool_names = list(self._tools.keys())

        schemas = []
        for name in tool_names:
            if name in self._tools:
                schemas.append(self._tools[name]["schema"])
            else:
                logger.warning(f"[ToolRegistry] Tool not found: {name}")

        return schemas

    def get_tool_names(self, current_url: Optional[str] = None) -> List[str]:
        """
        Return registered tool names, optionally filtered by URL.

        Args:
            current_url: Current page URL (None = return every tool, unfiltered).

        Returns:
            List of tool names.
        """
        if current_url is None:
            return list(self._tools.keys())

        tool_items = [
            {"name": name, "url_patterns": tool["url_patterns"]}
            for name, tool in self._tools.items()
        ]
        filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
        return [item["name"] for item in filtered]

    def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
        """
        Return the schemas of the tools whose URL patterns match ``current_url``.

        Args:
            current_url: Current page URL. None returns the schemas of ALL
                registered tools (no URL filtering is applied).

        Returns:
            Filtered list of tool schemas.
        """
        return self.get_schemas(self.get_tool_names(current_url))

    async def execute(
        self,
        name: str,
        arguments: Dict[str, Any],
        uid: str = "",
        context: Optional[Dict[str, Any]] = None,
        sensitive_data: Optional[Dict[str, Any]] = None,
        inject_values: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Dict[str, Any]]:
        """
        Execute a tool call and record its statistics.

        Args:
            name: Tool name.
            arguments: Arguments produced by the LLM (None is treated as {}).
            uid: User id; injected when "uid" is a hidden param of the tool.
            context: Extra context; injected when "context" is a hidden param,
                and used as the source for dict-style inject_params rules.
                context["page_url"] is passed to the sensitive-data replacer.
            sensitive_data: Mapping used to replace sensitive-data placeholders
                in ``arguments``.
            inject_values: Currently unused; kept for interface compatibility.

        Returns:
            A string (the tool's own string result, a JSON dump, or an
            ``{"error": ...}`` JSON string), or — for a ToolResult carrying
            images or tool_usage — a dict with "text" plus those keys.
        """
        if name not in self._tools:
            error_msg = f"Unknown tool: {name}"
            logger.error(f"[ToolRegistry] {error_msg}")
            return json.dumps({"error": error_msg}, ensure_ascii=False)

        tool_info = self._tools[name]
        func = tool_info["func"]

        start_time = time.time()
        stats = self._stats[name]
        stats.call_count += 1
        stats.last_called = start_time

        try:
            # Replace sensitive-data placeholders before the tool sees the args.
            if sensitive_data:
                from agent.tools.sensitive import replace_sensitive_data
                current_url = context.get("page_url") if context else None
                arguments = replace_sensitive_data(arguments, sensitive_data, current_url)

            kwargs = self._build_kwargs(tool_info, arguments, uid, context)

            # Awaiting any awaitable (not only `async def` functions) also
            # supports sync wrappers that return a coroutine.
            result = func(**kwargs)
            if inspect.isawaitable(result):
                result = await result

            # Serialize inside the try so a non-serializable result counts as
            # a failure only, never as both a success and a failure.
            output = self._serialize_result(result)
        except Exception as e:
            stats.failure_count += 1
            stats.total_duration += time.time() - start_time

            error_msg = f"Error executing tool '{name}': {str(e)}"
            logger.exception(f"[ToolRegistry] {error_msg}")
            return json.dumps({"error": error_msg}, ensure_ascii=False)

        stats.success_count += 1
        stats.total_duration += time.time() - start_time
        return output

    def _build_kwargs(
        self,
        tool_info: Dict[str, Any],
        arguments: Optional[Dict[str, Any]],
        uid: str,
        context: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the call kwargs: LLM arguments + hidden params + inject_params rules."""
        kwargs = {**(arguments or {})}
        params = inspect.signature(tool_info["func"]).parameters

        # Hidden params are only injected when the function actually accepts them.
        hidden_params = tool_info.get("hidden_params", [])
        if "uid" in hidden_params and "uid" in params:
            kwargs["uid"] = uid
        if "context" in hidden_params and "context" in params:
            kwargs["context"] = context

        for param_name, rule in tool_info.get("inject_params", {}).items():
            if param_name not in params:
                continue

            if not isinstance(rule, dict) or "mode" not in rule:
                # Legacy format: a plain value or a zero-argument callable,
                # used only when the LLM did not supply the parameter.
                if kwargs.get(param_name) is None:
                    kwargs[param_name] = rule() if callable(rule) else rule
                continue

            key_path = rule.get("key")
            value = self._resolve_key_path(context, key_path) if key_path and context else None
            if value is None:
                continue

            mode = rule["mode"]
            if mode == "default":
                # Default mode: inject only when the LLM left the param empty.
                if kwargs.get(param_name) is None:
                    kwargs[param_name] = value
            elif mode == "merge":
                # Merge mode: framework value is always kept; the LLM may add to it.
                kwargs[param_name] = self._merge_values(kwargs.get(param_name), value)
            else:
                logger.warning(
                    f"[ToolRegistry] Unknown inject mode '{mode}' for param '{param_name}'"
                )

        return kwargs

    @staticmethod
    def _merge_values(llm_value: Any, framework_value: Any) -> Any:
        """Merge an LLM-supplied value with a framework value (framework wins)."""
        if isinstance(framework_value, dict):
            # Dict: LLM may add new keys; on a key clash the framework value wins.
            # A non-dict LLM value cannot be merged and is discarded.
            base = llm_value if isinstance(llm_value, dict) else {}
            return {**base, **framework_value}

        if isinstance(framework_value, list):
            # List: concatenate then de-duplicate, keeping first-seen order
            # (LLM items first) so the result is deterministic.
            base = list(llm_value) if isinstance(llm_value, (list, tuple)) else []
            combined = base + framework_value
            try:
                return list(dict.fromkeys(combined))
            except TypeError:
                # Unhashable items (e.g. dicts): fall back to equality scanning.
                merged: List[Any] = []
                for item in combined:
                    if item not in merged:
                        merged.append(item)
                return merged

        # Scalars: the framework value replaces whatever the LLM supplied.
        return framework_value

    @staticmethod
    def _serialize_result(result: Any) -> Union[str, Dict[str, Any]]:
        """Convert a tool's return value into what execute() hands back."""
        if isinstance(result, str):
            return result

        from agent.tools.models import ToolResult
        if isinstance(result, ToolResult):
            ret: Dict[str, Any] = {"text": result.to_llm_message()}
            if result.images:
                ret["images"] = result.images
            if result.tool_usage:
                ret["tool_usage"] = result.tool_usage
            # Backward compatibility: text-only results are returned as a plain string.
            return ret["text"] if len(ret) == 1 else ret

        return json.dumps(result, ensure_ascii=False, indent=2)

    def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """
        Return usage statistics.

        Args:
            tool_name: Tool name (None = all tools).

        Returns:
            {tool_name: stats_dict}; empty if ``tool_name`` is unknown.
        """
        if tool_name is not None:
            if tool_name in self._stats:
                return {tool_name: self._stats[tool_name].to_dict()}
            return {}

        return {name: stats.to_dict() for name, stats in self._stats.items()}

    def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
        """
        Return the top-ranked tool names.

        Args:
            limit: Maximum number of names to return.
            by: Ranking key: "call_count" (desc), "success_rate" (desc) or
                "average_duration" (asc, faster first).

        Returns:
            List of tool names.

        Raises:
            ValueError: If ``by`` is not one of the supported keys.
        """
        # (key function, descending?) per ranking criterion.
        # NOTE: never-called tools have average_duration 0.0 and thus rank
        # as "fastest" under "average_duration".
        rankings = {
            "call_count": (lambda s: s.call_count, True),
            "success_rate": (lambda s: s.success_rate, True),
            "average_duration": (lambda s: s.average_duration, False),
        }
        if by not in rankings:
            raise ValueError(f"Invalid sort by: {by}")

        key_fn, descending = rankings[by]
        ranked = sorted(self._stats.items(), key=lambda item: key_fn(item[1]), reverse=descending)
        return [name for name, _ in ranked[:limit]]

    def _ui_metadata_for_call(self, tool_call: Dict) -> Optional[Dict[str, Any]]:
        """Return the ui_metadata of a tool call's tool, or None if unknown."""
        tool_name = tool_call.get("function", {}).get("name")
        if tool_name and tool_name in self._tools:
            return self._tools[tool_name]["ui_metadata"]
        return None

    def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
        """Return True if any of the tool calls requires user confirmation."""
        return any(self.get_confirmation_flags(tool_calls))

    def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
        """Return, per tool call, whether it requires confirmation (False if unknown)."""
        flags = []
        for tc in tool_calls:
            meta = self._ui_metadata_for_call(tc)
            flags.append(meta.get("requires_confirmation", False) if meta is not None else False)
        return flags

    def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
        """Return True if any of the tool calls has at least one editable parameter."""
        for tc in tool_calls:
            meta = self._ui_metadata_for_call(tc)
            if meta is not None and meta.get("editable_params", []):
                return True
        return False

    def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
        """Map each tool call id to its editable parameter list ([] if unknown)."""
        params_map = {}
        for tc in tool_calls:
            meta = self._ui_metadata_for_call(tc)
            params_map[tc.get("id")] = meta.get("editable_params", []) if meta is not None else []
        return params_map

    def get_ui_metadata(
        self,
        locale: str = "zh",
        tool_names: Optional[List[str]] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Return the UI metadata of tools, for frontend display.

        Args:
            locale: Key into each tool's ``display`` dict; falls back to the
                tool name when the locale has no entry.
            tool_names: Tools to include (None = all); unknown names are skipped.

        Returns:
            {
                "tool_name": {
                    "display_name": "搜索笔记",
                    "param_display_names": {"query": "搜索关键词"},
                    "requires_confirmation": false,
                    "editable_params": ["query"]
                }
            }
        """
        if tool_names is None:
            tool_names = list(self._tools.keys())

        metadata = {}
        for name in tool_names:
            if name not in self._tools:
                continue

            ui_meta = self._tools[name]["ui_metadata"]
            display = ui_meta.get("display", {}).get(locale, {})

            metadata[name] = {
                "display_name": display.get("name", name),
                "param_display_names": display.get("params", {}),
                "requires_confirmation": ui_meta.get("requires_confirmation", False),
                "editable_params": ui_meta.get("editable_params", []),
            }

        return metadata


# Process-wide singleton used by the @tool decorator.
_global_registry = ToolRegistry()


def tool(
    description: Optional[str] = None,
    param_descriptions: Optional[Dict[str, str]] = None,
    requires_confirmation: bool = False,
    editable_params: Optional[List[str]] = None,
    display: Optional[Dict[str, Dict[str, Any]]] = None,
    url_patterns: Optional[List[str]] = None,
    hidden_params: Optional[List[str]] = None,
    inject_params: Optional[Dict[str, Any]] = None,
):
    """
    Tool decorator: register the function in the global registry with an
    auto-generated schema and return the function unchanged.

    Args:
        description: Function description.
            NOTE(review): currently accepted but NOT forwarded to registration;
            the schema is generated from the function itself.
        param_descriptions: Parameter descriptions.
            NOTE(review): likewise accepted but not forwarded.
        requires_confirmation: Whether the user must confirm (default False).
        editable_params: Parameters the user may edit.
        display: i18n display info.
        url_patterns: URL patterns such as ["*.google.com"]; None = unrestricted.
        hidden_params: Parameters hidden from the schema (the LLM cannot see them).
        inject_params: {param_name: rule}; a rule is
            {"mode": "default" | "merge", "key": "<context key path>"}, or a
            plain value / zero-argument callable (legacy form).

    Example:
        @tool(
            hidden_params=["context", "uid"],
            inject_params={
                # Merge context["knowledge_config"].default_tags into `tags`.
                "tags": {"mode": "merge", "key": "knowledge_config.default_tags"},
            },
            editable_params=["query"],
            url_patterns=["*.google.com"],
            display={
                "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
                "en": {"name": "Search Notes", "params": {"query": "Query"}}
            }
        )
        async def search_blocks(
            query: str,
            limit: int = 10,
            tags: Optional[List[str]] = None,
            context: Optional[Dict[str, Any]] = None,
            uid: str = ""
        ) -> str:
            '''Search the user's note blocks.'''
            ...
    """
    def decorator(func: Callable) -> Callable:
        _global_registry.register(
            func,
            requires_confirmation=requires_confirmation,
            editable_params=editable_params,
            display=display,
            url_patterns=url_patterns,
            hidden_params=hidden_params,
            inject_params=inject_params,
        )
        return func

    return decorator


def get_tool_registry() -> ToolRegistry:
    """Return the global tool registry."""
    return _global_registry