registry.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. """
  2. Tool Registry - 工具注册表和装饰器
  3. 职责:
  4. 1. @tool 装饰器:自动注册工具并生成 Schema
  5. 2. 管理所有工具的 Schema 和实现
  6. 3. 路由工具调用到具体实现
  7. 4. 支持域名过滤、敏感数据处理、工具统计
  8. 从 Resonote/llm/tools/registry.py 抽取并扩展
  9. """
import inspect
import json
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Union

from agent.tools.url_matcher import filter_by_url
  16. logger = logging.getLogger(__name__)
  17. class ToolStats:
  18. """工具使用统计"""
  19. def __init__(self):
  20. self.call_count: int = 0
  21. self.success_count: int = 0
  22. self.failure_count: int = 0
  23. self.total_duration: float = 0.0
  24. self.last_called: Optional[float] = None
  25. @property
  26. def average_duration(self) -> float:
  27. """平均执行时间(秒)"""
  28. return self.total_duration / self.call_count if self.call_count > 0 else 0.0
  29. @property
  30. def success_rate(self) -> float:
  31. """成功率"""
  32. return self.success_count / self.call_count if self.call_count > 0 else 0.0
  33. def to_dict(self) -> Dict[str, Any]:
  34. return {
  35. "call_count": self.call_count,
  36. "success_count": self.success_count,
  37. "failure_count": self.failure_count,
  38. "average_duration": self.average_duration,
  39. "success_rate": self.success_rate,
  40. "last_called": self.last_called
  41. }
  42. class ToolRegistry:
  43. """工具注册表"""
  44. def __init__(self):
  45. self._tools: Dict[str, Dict[str, Any]] = {}
  46. self._stats: Dict[str, ToolStats] = {}
  47. def register(
  48. self,
  49. func: Callable,
  50. schema: Optional[Dict] = None,
  51. requires_confirmation: bool = False,
  52. editable_params: Optional[List[str]] = None,
  53. display: Optional[Dict[str, Dict[str, Any]]] = None,
  54. url_patterns: Optional[List[str]] = None
  55. ):
  56. """
  57. 注册工具
  58. Args:
  59. func: 工具函数
  60. schema: 工具 Schema(如果为 None,自动生成)
  61. requires_confirmation: 是否需要用户确认
  62. editable_params: 允许用户编辑的参数列表
  63. display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
  64. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  65. """
  66. func_name = func.__name__
  67. # 如果没有提供 Schema,自动生成
  68. if schema is None:
  69. try:
  70. from agent.tools.schema import SchemaGenerator
  71. schema = SchemaGenerator.generate(func)
  72. except Exception as e:
  73. logger.error(f"Failed to generate schema for {func_name}: {e}")
  74. raise
  75. self._tools[func_name] = {
  76. "func": func,
  77. "schema": schema,
  78. "url_patterns": url_patterns,
  79. "ui_metadata": {
  80. "requires_confirmation": requires_confirmation,
  81. "editable_params": editable_params or [],
  82. "display": display or {}
  83. }
  84. }
  85. # 初始化统计
  86. self._stats[func_name] = ToolStats()
  87. logger.debug(
  88. f"[ToolRegistry] Registered: {func_name} "
  89. f"(requires_confirmation={requires_confirmation}, "
  90. f"editable_params={editable_params or []}, "
  91. f"url_patterns={url_patterns or 'none'})"
  92. )
  93. def is_registered(self, tool_name: str) -> bool:
  94. """检查工具是否已注册"""
  95. return tool_name in self._tools
  96. def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
  97. """
  98. 获取工具 Schema
  99. Args:
  100. tool_names: 工具名称列表(None = 所有工具)
  101. Returns:
  102. OpenAI Tool Schema 列表
  103. """
  104. if tool_names is None:
  105. tool_names = list(self._tools.keys())
  106. schemas = []
  107. for name in tool_names:
  108. if name in self._tools:
  109. schemas.append(self._tools[name]["schema"])
  110. else:
  111. logger.warning(f"[ToolRegistry] Tool not found: {name}")
  112. return schemas
  113. def get_tool_names(self, current_url: Optional[str] = None) -> List[str]:
  114. """
  115. 获取工具名称列表(可选 URL 过滤)
  116. Args:
  117. current_url: 当前 URL(None = 返回所有工具)
  118. Returns:
  119. 工具名称列表
  120. """
  121. if current_url is None:
  122. return list(self._tools.keys())
  123. # 过滤工具
  124. tool_items = [
  125. {"name": name, "url_patterns": tool["url_patterns"]}
  126. for name, tool in self._tools.items()
  127. ]
  128. filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
  129. return [item["name"] for item in filtered]
  130. def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
  131. """
  132. 根据当前 URL 获取匹配的工具 Schema
  133. Args:
  134. current_url: 当前 URL(None = 返回无 URL 限制的工具)
  135. Returns:
  136. 过滤后的工具 Schema 列表
  137. """
  138. tool_names = self.get_tool_names(current_url)
  139. return self.get_schemas(tool_names)
  140. async def execute(
  141. self,
  142. name: str,
  143. arguments: Dict[str, Any],
  144. uid: str = "",
  145. context: Optional[Dict[str, Any]] = None,
  146. sensitive_data: Optional[Dict[str, Any]] = None
  147. ) -> str:
  148. """
  149. 执行工具调用
  150. Args:
  151. name: 工具名称
  152. arguments: 工具参数
  153. uid: 用户ID(自动注入)
  154. context: 额外上下文
  155. sensitive_data: 敏感数据字典(用于替换 <secret> 占位符)
  156. Returns:
  157. JSON 字符串格式的结果
  158. """
  159. if name not in self._tools:
  160. error_msg = f"Unknown tool: {name}"
  161. logger.error(f"[ToolRegistry] {error_msg}")
  162. return json.dumps({"error": error_msg}, ensure_ascii=False)
  163. start_time = time.time()
  164. stats = self._stats[name]
  165. stats.call_count += 1
  166. stats.last_called = start_time
  167. try:
  168. func = self._tools[name]["func"]
  169. # 处理敏感数据占位符
  170. if sensitive_data:
  171. from agent.tools.sensitive import replace_sensitive_data
  172. current_url = context.get("page_url") if context else None
  173. arguments = replace_sensitive_data(arguments, sensitive_data, current_url)
  174. # 准备参数:只注入函数需要的参数
  175. kwargs = {**arguments}
  176. sig = inspect.signature(func)
  177. # 注入 uid(如果函数接受)
  178. if "uid" in sig.parameters:
  179. kwargs["uid"] = uid
  180. # 注入 context(如果函数接受)
  181. if "context" in sig.parameters:
  182. kwargs["context"] = context
  183. # 执行函数
  184. if inspect.iscoroutinefunction(func):
  185. result = await func(**kwargs)
  186. else:
  187. result = func(**kwargs)
  188. # 记录成功
  189. stats.success_count += 1
  190. duration = time.time() - start_time
  191. stats.total_duration += duration
  192. # 返回结果:ToolResult 转为可序列化格式
  193. if isinstance(result, str):
  194. return result
  195. # 处理 ToolResult 对象
  196. from agent.tools.models import ToolResult
  197. if isinstance(result, ToolResult):
  198. ret = {"text": result.to_llm_message()}
  199. # 保留images
  200. if result.images:
  201. ret["images"] = result.images
  202. # 保留tool_usage
  203. if result.tool_usage:
  204. ret["tool_usage"] = result.tool_usage
  205. # 向后兼容:只有text时返回字符串
  206. if len(ret) == 1:
  207. return ret["text"]
  208. return ret
  209. return json.dumps(result, ensure_ascii=False, indent=2)
  210. except Exception as e:
  211. # 记录失败
  212. stats.failure_count += 1
  213. duration = time.time() - start_time
  214. stats.total_duration += duration
  215. error_msg = f"Error executing tool '{name}': {str(e)}"
  216. logger.error(f"[ToolRegistry] {error_msg}")
  217. import traceback
  218. logger.error(traceback.format_exc())
  219. return json.dumps({"error": error_msg}, ensure_ascii=False)
  220. def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
  221. """
  222. 获取工具统计信息
  223. Args:
  224. tool_name: 工具名称(None = 所有工具)
  225. Returns:
  226. 统计信息字典
  227. """
  228. if tool_name:
  229. if tool_name in self._stats:
  230. return {tool_name: self._stats[tool_name].to_dict()}
  231. return {}
  232. return {name: stats.to_dict() for name, stats in self._stats.items()}
  233. def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
  234. """
  235. 获取排名靠前的工具
  236. Args:
  237. limit: 返回数量
  238. by: 排序依据(call_count, success_rate, average_duration)
  239. Returns:
  240. 工具名称列表
  241. """
  242. if by == "call_count":
  243. sorted_tools = sorted(
  244. self._stats.items(),
  245. key=lambda x: x[1].call_count,
  246. reverse=True
  247. )
  248. elif by == "success_rate":
  249. sorted_tools = sorted(
  250. self._stats.items(),
  251. key=lambda x: x[1].success_rate,
  252. reverse=True
  253. )
  254. elif by == "average_duration":
  255. sorted_tools = sorted(
  256. self._stats.items(),
  257. key=lambda x: x[1].average_duration,
  258. reverse=False # 越快越好
  259. )
  260. else:
  261. raise ValueError(f"Invalid sort by: {by}")
  262. return [name for name, _ in sorted_tools[:limit]]
  263. def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
  264. """检查是否有工具需要用户确认"""
  265. for tc in tool_calls:
  266. tool_name = tc.get("function", {}).get("name")
  267. if tool_name and tool_name in self._tools:
  268. if self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False):
  269. return True
  270. return False
  271. def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
  272. """返回每个工具是否需要确认"""
  273. flags = []
  274. for tc in tool_calls:
  275. tool_name = tc.get("function", {}).get("name")
  276. if tool_name and tool_name in self._tools:
  277. flags.append(self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False))
  278. else:
  279. flags.append(False)
  280. return flags
  281. def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
  282. """检查是否有任何工具允许参数编辑"""
  283. for tc in tool_calls:
  284. tool_name = tc.get("function", {}).get("name")
  285. if tool_name and tool_name in self._tools:
  286. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  287. if editable_params:
  288. return True
  289. return False
  290. def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
  291. """返回每个工具调用的可编辑参数列表"""
  292. params_map = {}
  293. for tc in tool_calls:
  294. tool_call_id = tc.get("id")
  295. tool_name = tc.get("function", {}).get("name")
  296. if tool_name and tool_name in self._tools:
  297. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  298. params_map[tool_call_id] = editable_params
  299. else:
  300. params_map[tool_call_id] = []
  301. return params_map
  302. def get_ui_metadata(
  303. self,
  304. locale: str = "zh",
  305. tool_names: Optional[List[str]] = None
  306. ) -> Dict[str, Dict[str, Any]]:
  307. """
  308. 获取工具的UI元数据(用于前端展示)
  309. Returns:
  310. {
  311. "tool_name": {
  312. "display_name": "搜索笔记",
  313. "param_display_names": {"query": "搜索关键词"},
  314. "requires_confirmation": false,
  315. "editable_params": ["query"]
  316. }
  317. }
  318. """
  319. if tool_names is None:
  320. tool_names = list(self._tools.keys())
  321. metadata = {}
  322. for name in tool_names:
  323. if name not in self._tools:
  324. continue
  325. ui_meta = self._tools[name]["ui_metadata"]
  326. display = ui_meta.get("display", {}).get(locale, {})
  327. metadata[name] = {
  328. "display_name": display.get("name", name),
  329. "param_display_names": display.get("params", {}),
  330. "requires_confirmation": ui_meta.get("requires_confirmation", False),
  331. "editable_params": ui_meta.get("editable_params", [])
  332. }
  333. return metadata
  334. # 全局单例
  335. _global_registry = ToolRegistry()
  336. def tool(
  337. description: Optional[str] = None,
  338. param_descriptions: Optional[Dict[str, str]] = None,
  339. requires_confirmation: bool = False,
  340. editable_params: Optional[List[str]] = None,
  341. display: Optional[Dict[str, Dict[str, Any]]] = None,
  342. url_patterns: Optional[List[str]] = None
  343. ):
  344. """
  345. 工具装饰器 - 自动注册工具并生成 Schema
  346. Args:
  347. description: 函数描述(可选,从 docstring 提取)
  348. param_descriptions: 参数描述(可选,从 docstring 提取)
  349. requires_confirmation: 是否需要用户确认(默认 False)
  350. editable_params: 允许用户编辑的参数列表
  351. display: i18n 展示信息
  352. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  353. Example:
  354. @tool(
  355. editable_params=["query"],
  356. url_patterns=["*.google.com"],
  357. display={
  358. "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
  359. "en": {"name": "Search Notes", "params": {"query": "Query"}}
  360. }
  361. )
  362. async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
  363. '''搜索用户的笔记块'''
  364. ...
  365. """
  366. def decorator(func: Callable) -> Callable:
  367. # 注册到全局 registry
  368. _global_registry.register(
  369. func,
  370. requires_confirmation=requires_confirmation,
  371. editable_params=editable_params,
  372. display=display,
  373. url_patterns=url_patterns
  374. )
  375. return func
  376. return decorator
  377. def get_tool_registry() -> ToolRegistry:
  378. """获取全局工具注册表"""
  379. return _global_registry