# registry.py
  1. """
  2. Tool Registry - 工具注册表和装饰器
  3. 职责:
  4. 1. @tool 装饰器:自动注册工具并生成 Schema
  5. 2. 管理所有工具的 Schema 和实现
  6. 3. 路由工具调用到具体实现
  7. 4. 支持域名过滤、敏感数据处理、工具统计
  8. 从 Resonote/llm/tools/registry.py 抽取并扩展
  9. """
  10. import json
  11. import inspect
  12. import logging
  13. import time
  14. from typing import Any, Callable, Dict, List, Optional
  15. from agent.tools.url_matcher import filter_by_url
  16. logger = logging.getLogger(__name__)
  17. class ToolStats:
  18. """工具使用统计"""
  19. def __init__(self):
  20. self.call_count: int = 0
  21. self.success_count: int = 0
  22. self.failure_count: int = 0
  23. self.total_duration: float = 0.0
  24. self.last_called: Optional[float] = None
  25. @property
  26. def average_duration(self) -> float:
  27. """平均执行时间(秒)"""
  28. return self.total_duration / self.call_count if self.call_count > 0 else 0.0
  29. @property
  30. def success_rate(self) -> float:
  31. """成功率"""
  32. return self.success_count / self.call_count if self.call_count > 0 else 0.0
  33. def to_dict(self) -> Dict[str, Any]:
  34. return {
  35. "call_count": self.call_count,
  36. "success_count": self.success_count,
  37. "failure_count": self.failure_count,
  38. "average_duration": self.average_duration,
  39. "success_rate": self.success_rate,
  40. "last_called": self.last_called
  41. }
  42. class ToolRegistry:
  43. """工具注册表"""
  44. def __init__(self):
  45. self._tools: Dict[str, Dict[str, Any]] = {}
  46. self._stats: Dict[str, ToolStats] = {}
  47. def register(
  48. self,
  49. func: Callable,
  50. schema: Optional[Dict] = None,
  51. requires_confirmation: bool = False,
  52. editable_params: Optional[List[str]] = None,
  53. display: Optional[Dict[str, Dict[str, Any]]] = None,
  54. url_patterns: Optional[List[str]] = None,
  55. hidden_params: Optional[List[str]] = None,
  56. inject_params: Optional[Dict[str, Any]] = None,
  57. groups: Optional[List[str]] = None,
  58. ):
  59. """
  60. 注册工具
  61. Args:
  62. func: 工具函数
  63. schema: 工具 Schema(如果为 None,自动生成)
  64. requires_confirmation: 是否需要用户确认
  65. editable_params: 允许用户编辑的参数列表
  66. display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
  67. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  68. hidden_params: 隐藏参数列表(不生成 schema,LLM 看不到)
  69. inject_params: 注入参数规则 {param_name: injector_func}
  70. groups: 工具分组标签(如 ["core"]、["browser"]),用于 RunConfig.tool_groups 过滤
  71. """
  72. func_name = func.__name__
  73. # 如果没有提供 Schema,自动生成
  74. if schema is None:
  75. try:
  76. from agent.tools.schema import SchemaGenerator
  77. schema = SchemaGenerator.generate(func, hidden_params=hidden_params or [])
  78. except Exception as e:
  79. logger.error(f"Failed to generate schema for {func_name}: {e}")
  80. raise
  81. self._tools[func_name] = {
  82. "func": func,
  83. "schema": schema,
  84. "url_patterns": url_patterns,
  85. "hidden_params": hidden_params or [],
  86. "inject_params": inject_params or {},
  87. "groups": groups or [],
  88. "ui_metadata": {
  89. "requires_confirmation": requires_confirmation,
  90. "editable_params": editable_params or [],
  91. "display": display or {}
  92. }
  93. }
  94. # 初始化统计
  95. self._stats[func_name] = ToolStats()
  96. logger.debug(
  97. f"[ToolRegistry] Registered: {func_name} "
  98. f"(requires_confirmation={requires_confirmation}, "
  99. f"editable_params={editable_params or []}, "
  100. f"url_patterns={url_patterns or 'none'})"
  101. )
  102. @staticmethod
  103. def _resolve_key_path(context: Dict[str, Any], key_path: str) -> Any:
  104. """
  105. 从 context 中按路径取值。
  106. 支持 "obj.field" 格式:第一段从 context dict 取值,后续段用 getattr。
  107. 例如 "knowledge_config.default_tags" → context["knowledge_config"].default_tags
  108. Args:
  109. context: 上下文字典
  110. key_path: 取值路径
  111. Returns:
  112. 取到的值,路径无效返回 None
  113. """
  114. parts = key_path.split(".")
  115. value = context.get(parts[0])
  116. for part in parts[1:]:
  117. if value is None:
  118. return None
  119. value = getattr(value, part, None)
  120. return value
  121. def is_registered(self, tool_name: str) -> bool:
  122. """检查工具是否已注册"""
  123. return tool_name in self._tools
  124. def get_schemas(self, tool_names: Optional[List[str]] = None) -> List[Dict]:
  125. """
  126. 获取工具 Schema
  127. Args:
  128. tool_names: 工具名称列表(None = 所有工具)
  129. Returns:
  130. OpenAI Tool Schema 列表
  131. """
  132. if tool_names is None:
  133. tool_names = list(self._tools.keys())
  134. schemas = []
  135. for name in tool_names:
  136. if name in self._tools:
  137. schemas.append(self._tools[name]["schema"])
  138. else:
  139. logger.warning(f"[ToolRegistry] Tool not found: {name}")
  140. return schemas
  141. def get_tool_names(self, current_url: Optional[str] = None, groups: Optional[List[str]] = None) -> List[str]:
  142. """
  143. 获取工具名称列表(可选 URL 过滤 + group 过滤)
  144. Args:
  145. current_url: 当前 URL(None = 不过滤 URL)
  146. groups: 工具分组白名单(None = 不过滤 group,返回所有工具)
  147. Returns:
  148. 工具名称列表
  149. """
  150. # 1. group 过滤
  151. if groups is not None:
  152. group_set = set(groups)
  153. candidates = {
  154. name for name, tool in self._tools.items()
  155. if group_set & set(tool.get("groups", []))
  156. }
  157. else:
  158. candidates = set(self._tools.keys())
  159. # 2. URL 过滤
  160. if current_url is None:
  161. return list(candidates)
  162. tool_items = [
  163. {"name": name, "url_patterns": self._tools[name]["url_patterns"]}
  164. for name in candidates
  165. ]
  166. filtered = filter_by_url(tool_items, current_url, url_field="url_patterns")
  167. return [item["name"] for item in filtered]
  168. def get_available_groups(self) -> List[str]:
  169. """获取所有已注册的工具分组"""
  170. groups = set()
  171. for tool in self._tools.values():
  172. groups.update(tool.get("groups", []))
  173. return sorted(groups)
  174. def get_schemas_for_url(self, current_url: Optional[str] = None) -> List[Dict]:
  175. """
  176. 根据当前 URL 获取匹配的工具 Schema
  177. Args:
  178. current_url: 当前 URL(None = 返回无 URL 限制的工具)
  179. Returns:
  180. 过滤后的工具 Schema 列表
  181. """
  182. tool_names = self.get_tool_names(current_url)
  183. return self.get_schemas(tool_names)
  184. async def execute(
  185. self,
  186. name: str,
  187. arguments: Dict[str, Any],
  188. uid: str = "",
  189. context: Optional[Dict[str, Any]] = None,
  190. sensitive_data: Optional[Dict[str, Any]] = None,
  191. inject_values: Optional[Dict[str, Any]] = None
  192. ) -> str:
  193. """
  194. 执行工具调用
  195. Args:
  196. name: 工具名称
  197. arguments: 工具参数
  198. uid: 用户ID(自动注入)
  199. context: 额外上下文
  200. sensitive_data: 敏感数据字典(用于替换 <secret> 占位符)
  201. Returns:
  202. JSON 字符串格式的结果
  203. """
  204. if name not in self._tools:
  205. error_msg = f"Unknown tool: {name}"
  206. logger.error(f"[ToolRegistry] {error_msg}")
  207. return json.dumps({"error": error_msg}, ensure_ascii=False)
  208. start_time = time.time()
  209. stats = self._stats[name]
  210. stats.call_count += 1
  211. stats.last_called = start_time
  212. try:
  213. func = self._tools[name]["func"]
  214. tool_info = self._tools[name]
  215. # 处理敏感数据占位符
  216. if sensitive_data:
  217. from agent.tools.sensitive import replace_sensitive_data
  218. current_url = context.get("page_url") if context else None
  219. arguments = replace_sensitive_data(arguments, sensitive_data, current_url)
  220. # 准备参数:只注入函数需要的参数
  221. sig = inspect.signature(func)
  222. # 过滤掉函数签名中不存在的参数(如 Claude SDK 发送的 {"_": true} 占位符)
  223. valid_params = set(sig.parameters.keys())
  224. kwargs = {k: v for k, v in arguments.items() if k in valid_params}
  225. # 注入隐藏参数(hidden_params)
  226. hidden_params = tool_info.get("hidden_params", [])
  227. if "uid" in hidden_params and "uid" in sig.parameters:
  228. kwargs["uid"] = uid
  229. if "context" in hidden_params and "context" in sig.parameters:
  230. kwargs["context"] = context
  231. # 注入参数(inject_params)
  232. inject_params = tool_info.get("inject_params", {})
  233. for param_name, rule in inject_params.items():
  234. if param_name not in sig.parameters:
  235. continue
  236. if not isinstance(rule, dict) or "mode" not in rule:
  237. # 兼容旧格式:直接值或 callable
  238. if param_name not in kwargs or kwargs[param_name] is None:
  239. kwargs[param_name] = rule() if callable(rule) else rule
  240. continue
  241. mode = rule["mode"]
  242. key_path = rule.get("key")
  243. # 从 context 中按路径取值
  244. value = self._resolve_key_path(context, key_path) if key_path and context else None
  245. if value is None:
  246. continue
  247. if mode == "default":
  248. # 默认值模式:LLM 未提供则注入
  249. if param_name not in kwargs or kwargs[param_name] is None:
  250. kwargs[param_name] = value
  251. elif mode == "merge":
  252. # 合并模式:框架值始终保留,LLM 可追加新内容
  253. llm_value = kwargs.get(param_name)
  254. if isinstance(value, dict):
  255. # dict: LLM 追加新 key,同名 key 以框架值为准
  256. kwargs[param_name] = {**(llm_value or {}), **value}
  257. elif isinstance(value, list):
  258. # list: 合并去重
  259. kwargs[param_name] = list(set((llm_value or []) + value))
  260. else:
  261. kwargs[param_name] = value
  262. # 执行函数
  263. if inspect.iscoroutinefunction(func):
  264. result = await func(**kwargs)
  265. else:
  266. result = func(**kwargs)
  267. # 记录成功
  268. stats.success_count += 1
  269. duration = time.time() - start_time
  270. stats.total_duration += duration
  271. # 返回结果:ToolResult 转为可序列化格式
  272. if isinstance(result, str):
  273. return result
  274. # 处理 ToolResult 对象
  275. from agent.tools.models import ToolResult
  276. if isinstance(result, ToolResult):
  277. ret = {"text": result.to_llm_message()}
  278. # 保留images
  279. if result.images:
  280. ret["images"] = result.images
  281. # 保留tool_usage
  282. if result.tool_usage:
  283. ret["tool_usage"] = result.tool_usage
  284. # 向后兼容:只有text时返回字符串
  285. if len(ret) == 1:
  286. return ret["text"]
  287. return ret
  288. return json.dumps(result, ensure_ascii=False, indent=2)
  289. except Exception as e:
  290. # 记录失败
  291. stats.failure_count += 1
  292. duration = time.time() - start_time
  293. stats.total_duration += duration
  294. error_msg = f"Error executing tool '{name}': {str(e)}"
  295. logger.error(f"[ToolRegistry] {error_msg}")
  296. import traceback
  297. logger.error(traceback.format_exc())
  298. return json.dumps({"error": error_msg}, ensure_ascii=False)
  299. def get_stats(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
  300. """
  301. 获取工具统计信息
  302. Args:
  303. tool_name: 工具名称(None = 所有工具)
  304. Returns:
  305. 统计信息字典
  306. """
  307. if tool_name:
  308. if tool_name in self._stats:
  309. return {tool_name: self._stats[tool_name].to_dict()}
  310. return {}
  311. return {name: stats.to_dict() for name, stats in self._stats.items()}
  312. def get_top_tools(self, limit: int = 10, by: str = "call_count") -> List[str]:
  313. """
  314. 获取排名靠前的工具
  315. Args:
  316. limit: 返回数量
  317. by: 排序依据(call_count, success_rate, average_duration)
  318. Returns:
  319. 工具名称列表
  320. """
  321. if by == "call_count":
  322. sorted_tools = sorted(
  323. self._stats.items(),
  324. key=lambda x: x[1].call_count,
  325. reverse=True
  326. )
  327. elif by == "success_rate":
  328. sorted_tools = sorted(
  329. self._stats.items(),
  330. key=lambda x: x[1].success_rate,
  331. reverse=True
  332. )
  333. elif by == "average_duration":
  334. sorted_tools = sorted(
  335. self._stats.items(),
  336. key=lambda x: x[1].average_duration,
  337. reverse=False # 越快越好
  338. )
  339. else:
  340. raise ValueError(f"Invalid sort by: {by}")
  341. return [name for name, _ in sorted_tools[:limit]]
  342. def check_confirmation_required(self, tool_calls: List[Dict]) -> bool:
  343. """检查是否有工具需要用户确认"""
  344. for tc in tool_calls:
  345. tool_name = tc.get("function", {}).get("name")
  346. if tool_name and tool_name in self._tools:
  347. if self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False):
  348. return True
  349. return False
  350. def get_confirmation_flags(self, tool_calls: List[Dict]) -> List[bool]:
  351. """返回每个工具是否需要确认"""
  352. flags = []
  353. for tc in tool_calls:
  354. tool_name = tc.get("function", {}).get("name")
  355. if tool_name and tool_name in self._tools:
  356. flags.append(self._tools[tool_name]["ui_metadata"].get("requires_confirmation", False))
  357. else:
  358. flags.append(False)
  359. return flags
  360. def check_any_param_editable(self, tool_calls: List[Dict]) -> bool:
  361. """检查是否有任何工具允许参数编辑"""
  362. for tc in tool_calls:
  363. tool_name = tc.get("function", {}).get("name")
  364. if tool_name and tool_name in self._tools:
  365. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  366. if editable_params:
  367. return True
  368. return False
  369. def get_editable_params_map(self, tool_calls: List[Dict]) -> Dict[str, List[str]]:
  370. """返回每个工具调用的可编辑参数列表"""
  371. params_map = {}
  372. for tc in tool_calls:
  373. tool_call_id = tc.get("id")
  374. tool_name = tc.get("function", {}).get("name")
  375. if tool_name and tool_name in self._tools:
  376. editable_params = self._tools[tool_name]["ui_metadata"].get("editable_params", [])
  377. params_map[tool_call_id] = editable_params
  378. else:
  379. params_map[tool_call_id] = []
  380. return params_map
  381. def get_ui_metadata(
  382. self,
  383. locale: str = "zh",
  384. tool_names: Optional[List[str]] = None
  385. ) -> Dict[str, Dict[str, Any]]:
  386. """
  387. 获取工具的UI元数据(用于前端展示)
  388. Returns:
  389. {
  390. "tool_name": {
  391. "display_name": "搜索笔记",
  392. "param_display_names": {"query": "搜索关键词"},
  393. "requires_confirmation": false,
  394. "editable_params": ["query"]
  395. }
  396. }
  397. """
  398. if tool_names is None:
  399. tool_names = list(self._tools.keys())
  400. metadata = {}
  401. for name in tool_names:
  402. if name not in self._tools:
  403. continue
  404. ui_meta = self._tools[name]["ui_metadata"]
  405. display = ui_meta.get("display", {}).get(locale, {})
  406. metadata[name] = {
  407. "display_name": display.get("name", name),
  408. "param_display_names": display.get("params", {}),
  409. "requires_confirmation": ui_meta.get("requires_confirmation", False),
  410. "editable_params": ui_meta.get("editable_params", [])
  411. }
  412. return metadata
  413. # 全局单例
  414. _global_registry = ToolRegistry()
  415. def tool(
  416. description: Optional[str] = None,
  417. param_descriptions: Optional[Dict[str, str]] = None,
  418. requires_confirmation: bool = False,
  419. editable_params: Optional[List[str]] = None,
  420. display: Optional[Dict[str, Dict[str, Any]]] = None,
  421. url_patterns: Optional[List[str]] = None,
  422. hidden_params: Optional[List[str]] = None,
  423. inject_params: Optional[Dict[str, Any]] = None,
  424. groups: Optional[List[str]] = None,
  425. ):
  426. """
  427. 工具装饰器 - 自动注册工具并生成 Schema
  428. Args:
  429. description: 函数描述(可选,从 docstring 提取)
  430. param_descriptions: 参数描述(可选,从 docstring 提取)
  431. requires_confirmation: 是否需要用户确认(默认 False)
  432. editable_params: 允许用户编辑的参数列表
  433. display: i18n 展示信息
  434. url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
  435. hidden_params: 隐藏参数列表(不生成 schema,LLM 看不到)
  436. inject_params: 注入参数规则 {param_name: injector_func}
  437. groups: 工具分组标签(如 ["core"]、["browser"]),用于 RunConfig.tool_groups 过滤
  438. Example:
  439. @tool(
  440. hidden_params=["context", "uid"],
  441. inject_params={
  442. "owner": lambda ctx: ctx.config.knowledge.get_owner(),
  443. },
  444. editable_params=["query"],
  445. url_patterns=["*.google.com"],
  446. display={
  447. "zh": {"name": "搜索笔记", "params": {"query": "搜索关键词"}},
  448. "en": {"name": "Search Notes", "params": {"query": "Query"}}
  449. }
  450. )
  451. async def search_blocks(
  452. query: str,
  453. limit: int = 10,
  454. owner: Optional[str] = None,
  455. context: Optional[ToolContext] = None,
  456. uid: str = ""
  457. ) -> str:
  458. '''搜索用户的笔记块'''
  459. ...
  460. """
  461. def decorator(func: Callable) -> Callable:
  462. # 注册到全局 registry
  463. _global_registry.register(
  464. func,
  465. requires_confirmation=requires_confirmation,
  466. editable_params=editable_params,
  467. display=display,
  468. url_patterns=url_patterns,
  469. hidden_params=hidden_params,
  470. inject_params=inject_params,
  471. groups=groups,
  472. )
  473. return func
  474. return decorator
  475. def get_tool_registry() -> ToolRegistry:
  476. """获取全局工具注册表"""
  477. return _global_registry