registry.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. """
  2. Tool Registry - 工具注册表和装饰器
  3. 职责:
  4. 1. @tool 装饰器:自动注册工具并生成 Schema
  5. 2. 管理所有工具的 Schema 和实现
  6. 3. 路由工具调用到具体实现
  7. 4. 支持域名过滤、敏感数据处理、工具统计
  8. 从 Resonote/llm/tools/registry.py 抽取并扩展
  9. """
  10. import json
  11. import inspect
  12. import logging
  13. import time
  14. from typing import Any, Callable, Dict, List, Optional
  15. from reson_agent.tools.url_matcher import filter_by_url
  16. logger = logging.getLogger(__name__)
  17. class ToolStats:
  18. """工具使用统计"""
  19. def __init__(self):
  20. self.call_count: int = 0
  21. self.success_count: int = 0
  22. self.failure_count: int = 0
  23. self.total_duration: float = 0.0
  24. self.last_called: Optional[float] = None
  25. @property
  26. def average_duration(self) -> float:
  27. """平均执行时间(秒)"""
  28. return self.total_duration / self.call_count if self.call_count > 0 else 0.0
  29. @property
  30. def success_rate(self) -> float:
  31. """成功率"""
  32. return self.success_count / self.call_count if self.call_count > 0 else 0.0
  33. def to_dict(self) -> Dict[str, Any]:
  34. return {
  35. "call_count": self.call_count,
  36. "success_count": self.success_count,
  37. "failure_count": self.failure_count,
  38. "average_duration": self.average_duration,
  39. "success_rate": self.success_rate,
  40. "last_called": self.last_called
  41. }
  42. class ToolRegistry:
  43. """工具注册表"""
  44. def __init__(self):
  45. self._tools: Dict[str, Dict[str, Any]] = {}
  46. self._stats: Dict[str, ToolStats] = {}
  47. def register(
  48. self,
  49. func: Callable,
  50. schema: Optional[Dict] = None,
  51. requires_confirmation: bool = False,
  52. editable_params: Optional[List[str]] = None,
  53. display: Optional[Dict[str, Dict[str, Any]]] = None,
  54. url_patterns: Optional[List[str]] = None
  55. ):
  56. """
  57. 注册工具
  58. Args:
  59. func: 工具函数
  60. schema: 工具 Schema(如果为 None,自动生成)
  61. requires_confirmation: 是否需要用户确认
  62. editable_params: 允许用户编辑的参数列表
  63. display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
  64. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  65. """
  66. func_name = func.__name__
  67. # 如果没有提供 Schema,自动生成
  68. if schema is None:
  69. try:
  70. from reson_agent.tools.schema import SchemaGenerator
  71. schema = SchemaGenerator.generate(func)
  72. except Exception as e:
  73. logger.error(f"Failed to generate schema for {func_name}: {e}")
  74. raise
  75. self._tools[func_name] = {
  76. "func": func,
  77. "schema": schema,
  78. "url_patterns": url_patterns,
  79. "ui_metadata": {
  80. "requires_confirmation": requires_confirmation,
  81. "editable_params": editable_params or [],
  82. "display": display or {}
  83. }
  84. }
  85. # 初始化统计
  86. self._stats[func_name] = ToolStats()
  87. logger.debug(
  88. f"[ToolRegistry] Registered: {func_name} "
  89. f"(requires_confirmation={requires_confirmation}, "
  90. f"editable_params={editable_params or []}, "
  91. f"url_patterns={url_patterns or 'none'})"
  92. )
  93. def is_registered(self, tool_name: str) -> bool:
  94. """检查工具是否已注册"""
  95. return tool_name in self._tools
  96. def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
  97. """
  98. 获取工具 Schema
  99. Args:
  100. tool_names: 工具名称列表(None = 所有工具)
  101. Returns:
  102. OpenAI Tool Schema 列表
  103. """
  104. if tool_names is None:
  105. tool_names = list(self._tools.keys())
  106. schemas = []
  107. for name in tool_names:
  108. if name in self._tools:
  109. schemas.append(self._tools[name]["schema"])
  110. else:
  111. logger.warning(f"[ToolRegistry] Tool not found: {name}")
  112. return schemas
  113. def get_tool_names(self, current_url: Optional[str] = None) -> List[str]:
  114. """
  115. 获取工具名称列表(可选 URL 过滤)
  116. Args:
  117. current_url: 当前 URL(None = 返回所有工具)
  118. Returns:
  119. 工具名称列表
  120. """
  121. if current_url is None:
  122. return list(self._tools.keys())
  123. # 过滤工具
  124. tool_items = [
  125. {"name": name, "url_patterns": tool["url_patterns"]}
  126. for name, tool in self._tools.items()
  127. ]
  128. filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
  129. return [item["name"] for item in filtered]
  130. def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
  131. """
  132. 根据当前 URL 获取匹配的工具 Schema
  133. Args:
  134. current_url: 当前 URL(None = 返回无 URL 限制的工具)
  135. Returns:
  136. 过滤后的工具 Schema 列表
  137. """
  138. tool_names = self.get_tool_names(current_url)
  139. return self.get_schemas(tool_names)
  140. async def execute(
  141. self,
  142. name: str,
  143. arguments: Dict[str, Any],
  144. uid: str = "",
  145. context: Optional[Dict[str, Any]] = None,
  146. sensitive_data: Optional[Dict[str, Any]] = None
  147. ) -> str:
  148. """
  149. 执行工具调用
  150. Args:
  151. name: 工具名称
  152. arguments: 工具参数
  153. uid: 用户ID(自动注入)
  154. context: 额外上下文
  155. sensitive_data: 敏感数据字典(用于替换 <secret> 占位符)
  156. Returns:
  157. JSON 字符串格式的结果
  158. """
  159. if name not in self._tools:
  160. error_msg = f"Unknown tool: {name}"
  161. logger.error(f"[ToolRegistry] {error_msg}")
  162. return json.dumps({"error": error_msg}, ensure_ascii=False)
  163. start_time = time.time()
  164. stats = self._stats[name]
  165. stats.call_count += 1
  166. stats.last_called = start_time
  167. try:
  168. func = self._tools[name]["func"]
  169. # 处理敏感数据占位符
  170. if sensitive_data:
  171. from reson_agent.tools.sensitive import replace_sensitive_data
  172. current_url = context.get("page_url") if context else None
  173. arguments = replace_sensitive_data(arguments, sensitive_data, current_url)
  174. # 注入 uid
  175. kwargs = {**arguments, "uid": uid}
  176. # 注入 context(如果函数接受)
  177. sig = inspect.signature(func)
  178. if "context" in sig.parameters:
  179. kwargs["context"] = context
  180. # 执行函数
  181. if inspect.iscoroutinefunction(func):
  182. result = await func(**kwargs)
  183. else:
  184. result = func(**kwargs)
  185. # 记录成功
  186. stats.success_count += 1
  187. duration = time.time() - start_time
  188. stats.total_duration += duration
  189. # 返回 JSON 字符串
  190. if isinstance(result, str):
  191. return result
  192. return json.dumps(result, ensure_ascii=False, indent=2)
  193. except Exception as e:
  194. # 记录失败
  195. stats.failure_count += 1
  196. duration = time.time() - start_time
  197. stats.total_duration += duration
  198. error_msg = f"Error executing tool '{name}': {str(e)}"
  199. logger.error(f"[ToolRegistry] {error_msg}")
  200. import traceback
  201. logger.error(traceback.format_exc())
  202. return json.dumps({"error": error_msg}, ensure_ascii=False)
  203. def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
  204. """
  205. 获取工具统计信息
  206. Args:
  207. tool_name: 工具名称(None = 所有工具)
  208. Returns:
  209. 统计信息字典
  210. """
  211. if tool_name:
  212. if tool_name in self._stats:
  213. return {tool_name: self._stats[tool_name].to_dict()}
  214. return {}
  215. return {name: stats.to_dict() for name, stats in self._stats.items()}
  216. def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
  217. """
  218. 获取排名靠前的工具
  219. Args:
  220. limit: 返回数量
  221. by: 排序依据(call_count, success_rate, average_duration)
  222. Returns:
  223. 工具名称列表
  224. """
  225. if by == "call_count":
  226. sorted_tools = sorted(
  227. self._stats.items(),
  228. key=lambda x: x[1].call_count,
  229. reverse=True
  230. )
  231. elif by == "success_rate":
  232. sorted_tools = sorted(
  233. self._stats.items(),
  234. key=lambda x: x[1].success_rate,
  235. reverse=True
  236. )
  237. elif by == "average_duration":
  238. sorted_tools = sorted(
  239. self._stats.items(),
  240. key=lambda x: x[1].average_duration,
  241. reverse=False # 越快越好
  242. )
  243. else:
  244. raise ValueError(f"Invalid sort by: {by}")
  245. return [name for name, _ in sorted_tools[:limit]]
  246. def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
  247. """检查是否有工具需要用户确认"""
  248. for tc in tool_calls:
  249. tool_name = tc.get("function", {}).get("name")
  250. if tool_name and tool_name in self._tools:
  251. if self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False):
  252. return True
  253. return False
  254. def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
  255. """返回每个工具是否需要确认"""
  256. flags = []
  257. for tc in tool_calls:
  258. tool_name = tc.get("function", {}).get("name")
  259. if tool_name and tool_name in self._tools:
  260. flags.append(self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False))
  261. else:
  262. flags.append(False)
  263. return flags
  264. def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
  265. """检查是否有任何工具允许参数编辑"""
  266. for tc in tool_calls:
  267. tool_name = tc.get("function", {}).get("name")
  268. if tool_name and tool_name in self._tools:
  269. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  270. if editable_params:
  271. return True
  272. return False
  273. def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
  274. """返回每个工具调用的可编辑参数列表"""
  275. params_map = {}
  276. for tc in tool_calls:
  277. tool_call_id = tc.get("id")
  278. tool_name = tc.get("function", {}).get("name")
  279. if tool_name and tool_name in self._tools:
  280. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  281. params_map[tool_call_id] = editable_params
  282. else:
  283. params_map[tool_call_id] = []
  284. return params_map
  285. def get_ui_metadata(
  286. self,
  287. locale: str = "zh",
  288. tool_names: Optional[List[str]] = None
  289. ) -> Dict[str, Dict[str, Any]]:
  290. """
  291. 获取工具的UI元数据(用于前端展示)
  292. Returns:
  293. {
  294. "tool_name": {
  295. "display_name": "搜索笔记",
  296. "param_display_names": {"query": "搜索关键词"},
  297. "requires_confirmation": false,
  298. "editable_params": ["query"]
  299. }
  300. }
  301. """
  302. if tool_names is None:
  303. tool_names = list(self._tools.keys())
  304. metadata = {}
  305. for name in tool_names:
  306. if name not in self._tools:
  307. continue
  308. ui_meta = self._tools[name]["ui_metadata"]
  309. display = ui_meta.get("display", {}).get(locale, {})
  310. metadata[name] = {
  311. "display_name": display.get("name", name),
  312. "param_display_names": display.get("params", {}),
  313. "requires_confirmation": ui_meta.get("requires_confirmation", False),
  314. "editable_params": ui_meta.get("editable_params", [])
  315. }
  316. return metadata
# Module-level shared registry (singleton): the @tool decorator registers
# into it and get_tool_registry() returns it.
_global_registry = ToolRegistry()
  319. def tool(
  320. description: Optional[str] = None,
  321. param_descriptions: Optional[Dict[str, str]] = None,
  322. requires_confirmation: bool = False,
  323. editable_params: Optional[List[str]] = None,
  324. display: Optional[Dict[str, Dict[str, Any]]] = None,
  325. url_patterns: Optional[List[str]] = None
  326. ):
  327. """
  328. 工具装饰器 - 自动注册工具并生成 Schema
  329. Args:
  330. description: 函数描述(可选,从 docstring 提取)
  331. param_descriptions: 参数描述(可选,从 docstring 提取)
  332. requires_confirmation: 是否需要用户确认(默认 False)
  333. editable_params: 允许用户编辑的参数列表
  334. display: i18n 展示信息
  335. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  336. Example:
  337. @tool(
  338. editable_params=["query"],
  339. url_patterns=["*.google.com"],
  340. display={
  341. "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
  342. "en": {"name": "Search Notes", "params": {"query": "Query"}}
  343. }
  344. )
  345. async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
  346. '''搜索用户的笔记块'''
  347. ...
  348. """
  349. def decorator(func: Callable) -> Callable:
  350. # 注册到全局 registry
  351. _global_registry.register(
  352. func,
  353. requires_confirmation=requires_confirmation,
  354. editable_params=editable_params,
  355. display=display,
  356. url_patterns=url_patterns
  357. )
  358. return func
  359. return decorator
  360. def get_tool_registry() -> ToolRegistry:
  361. """获取全局工具注册表"""
  362. return _global_registry