Преглед на файлове

refactor: inject_params and hidden_params

Talegorithm преди 2 дни
родител
ревизия
e81df46e04

+ 119 - 35
agent/docs/tools.md

@@ -25,7 +25,7 @@
 from reson_agent import tool, ToolResult, ToolContext
 
 @tool()
-async def my_tool(arg: str, ctx: ToolContext) -> ToolResult:
+async def my_tool(arg: str, context: Optional[ToolContext] = None) -> ToolResult:
     return ToolResult(
         title="Success",
         output="Result content"
@@ -44,16 +44,53 @@ async def my_tool(arg: str, ctx: ToolContext) -> ToolResult:
 1. 定义工具
    ↓ @tool() 装饰器
 2. 自动注册到 ToolRegistry
-   ↓ 生成 OpenAI Tool Schema
+   ↓ 生成 OpenAI Tool Schema(跳过 hidden_params)
 3. LLM 选择工具并生成参数
    ↓ registry.execute(name, args)
-4. 注入 uid 和 context
+4. 注入框架参数(hidden_params + inject_params)
    ↓ 调用工具函数
 5. 返回 ToolResult
    ↓ 转换为 LLM 消息
 6. 添加到对话历史
 ```
 
+### 参数注入机制
+
+工具参数分为三类:
+
+1. **业务参数**:LLM 可见,由 LLM 填写(如 `query`, `limit`)
+2. **隐藏参数**:LLM 不可见,框架自动注入(如 `context`, `uid`)
+3. **注入参数**:LLM 可见但有默认值,框架自动注入默认值(如 `owner`, `tags`)
+
+```python
+@tool(
+    hidden_params=["context", "uid"],  # 不生成 schema,LLM 看不到
+    inject_params={                     # 自动注入默认值
+        "owner": lambda ctx: ctx.config.knowledge.get_owner(),
+        "tags": lambda ctx, args: {**ctx.config.default_tags, **args.get("tags", {})},
+    }
+)
+async def knowledge_save(
+    task: str,                          # 业务参数:LLM 填写
+    content: str,                       # 业务参数:LLM 填写
+    types: List[str],                   # 业务参数:LLM 填写
+    tags: Optional[Dict] = None,        # 注入参数:LLM 可填,框架提供默认值
+    owner: Optional[str] = None,        # 注入参数:LLM 可填,框架提供默认值
+    context: Optional[ToolContext] = None,  # 隐藏参数:LLM 看不到
+    uid: str = "",                      # 隐藏参数:LLM 看不到
+) -> ToolResult:
+    """保存知识到知识库"""
+    ...
+```
+
+**注入时机**:
+- Schema 生成时:跳过 `hidden_params`,不暴露给 LLM
+- 工具执行前:注入 `hidden_params` 和 `inject_params` 的默认值
+
+**实现位置**:
+- Schema 生成:`agent/tools/schema.py:SchemaGenerator.generate()`
+- 参数注入:`agent/tools/registry.py:ToolRegistry.execute()`
+
 ---
 
 ## 定义工具
@@ -64,23 +101,24 @@ async def my_tool(arg: str, ctx: ToolContext) -> ToolResult:
 from reson_agent import tool
 
 @tool()
-async def hello(name: str, uid: str = "") -> str:
+async def hello(name: str) -> str:
     """向用户问好"""
     return f"Hello, {name}!"
 ```
 
 **要点**:
-- `uid` 参数由框架自动注入(用户不传递)
 - 可以是同步或异步函数
 - 返回值自动序列化为 JSON
+- 所有参数默认对 LLM 可见
 
-### 带完整注释
+### 带框架参数
 
 ```python
-@tool()
+@tool(hidden_params=["context", "uid"])
 async def search_notes(
     query: str,
     limit: int = 10,
+    context: Optional[ToolContext] = None,
     uid: str = ""
 ) -> str:
     """
@@ -89,14 +127,50 @@ async def search_notes(
     Args:
         query: 搜索关键词
         limit: 返回结果数量
+    """
+    # context 和 uid 由框架注入,LLM 看不到这两个参数
+    ...
+```
 
-    Returns:
-        JSON 格式的搜索结果
+### 带参数注入
+
+```python
+@tool(
+    hidden_params=["context"],
+    inject_params={
+        "owner": lambda ctx: ctx.config.knowledge.get_owner(),
+        "tags": lambda ctx, args: {**ctx.config.default_tags, **args.get("tags", {})},
+    }
+)
+async def knowledge_save(
+    task: str,
+    content: str,
+    types: List[str],
+    tags: Optional[Dict] = None,  # LLM 可填,框架提供默认值
+    owner: Optional[str] = None,  # LLM 可填,框架提供默认值
+    context: Optional[ToolContext] = None
+) -> ToolResult:
     """
-    # 支持自动从 docstring 提取 function description 和 parameter descriptions
+    保存知识
+
+    Args:
+        task: 任务描述
+        content: 知识内容
+        types: 知识类型
+        tags: 业务标签(可选,有默认值)
+        owner: 所有者(可选,有默认值)
+    """
+    # owner 和 tags 如果 LLM 未提供,框架会注入默认值
     ...
 ```
 
+**注入规则**:
+- `inject_params` 的 value 可以是:
+  - `lambda ctx: ...` - 从 context 计算
+  - `lambda ctx, args: ...` - 从 context 和已有参数计算
+  - 字符串 - 直接使用该值
+- 注入时机:工具执行前,使用 `setdefault` 注入(不覆盖 LLM 提供的值)
+
 ### 带 UI 元数据
 
 ```python
@@ -228,16 +302,16 @@ async def generate_report() -> ToolResult:
 
 ### 基本概念
 
-工具函数可以声明需要 `ToolContext` 参数,框架自动注入。
+工具函数可以声明需要 `ToolContext` 参数,框架自动注入。需要在 `@tool()` 装饰器中声明 `hidden_params=["context"]`,使其对 LLM 不可见。
 
 ```python
 from reson_agent import ToolContext
 
-@tool()
-async def get_current_state(ctx: ToolContext) -> ToolResult:
+@tool(hidden_params=["context"])
+async def get_current_state(context: Optional[ToolContext] = None) -> ToolResult:
     return ToolResult(
         title="Current state",
-        output=f"Trace ID: {ctx.trace_id}\nStep ID: {ctx.step_id}"
+        output=f"Trace ID: {context.trace_id}\nStep ID: {context.step_id}"
     )
 ```
 
@@ -250,56 +324,66 @@ class ToolContext(Protocol):
     step_id: str                # 当前 Step ID
     uid: Optional[str]          # 用户 ID
 
+    # 扩展字段(由 runner 注入)
+    store: Optional[TraceStore]     # Trace 存储
+    runner: Optional[AgentRunner]   # Runner 实例
+    goal_tree: Optional[GoalTree]   # 目标树
+    goal_id: Optional[str]          # 当前 Goal ID
+    config: Optional[RunConfig]     # 运行配置
+
     # 浏览器相关(Browser-Use 集成)
     browser_session: Optional[Any]      # 浏览器会话
     page_url: Optional[str]             # 当前页面 URL
     file_system: Optional[Any]          # 文件系统访问
     sensitive_data: Optional[Dict]      # 敏感数据
 
-    # 扩展字段
-    context: Optional[Dict[str, Any]]   # 额外上下文
+    # 额外上下文
+    context: Optional[Dict[str, Any]]   # 额外上下文数据
 ```
 
 ### 使用示例
 
 ```python
-@tool()
-async def analyze_current_page(ctx: ToolContext) -> ToolResult:
+@tool(hidden_params=["context"])
+async def analyze_current_page(context: Optional[ToolContext] = None) -> ToolResult:
     """分析当前浏览器页面"""
 
-    if not ctx.browser_session:
+    if not context or not context.browser_session:
         return ToolResult(
             title="Error",
             error="Browser session not available"
         )
 
     # 使用浏览器会话
-    page_content = await ctx.browser_session.get_content()
+    page_content = await context.browser_session.get_content()
 
     return ToolResult(
-        title=f"Analyzed {ctx.page_url}",
+        title=f"Analyzed {context.page_url}",
         output=page_content,
-        long_term_memory=f"Analyzed page at {ctx.page_url}"
+        long_term_memory=f"Analyzed page at {context.page_url}"
     )
 ```
 
 ### 创建 ToolContext
 
-```python
-from reson_agent import ToolContextImpl
+Runner 在执行工具时自动创建并注入 context:
 
-ctx = ToolContextImpl(
-    trace_id="trace_123",
-    step_id="step_456",
-    uid="user_789",
-    page_url="https://example.com"
-)
+```python
+# 在 AgentRunner._agent_loop 中
+context = {
+    "store": self.trace_store,
+    "trace_id": trace_id,
+    "goal_id": current_goal_id,
+    "runner": self,
+    "goal_tree": goal_tree,
+    "config": config,
+}
 
-# 执行工具
-result = await registry.execute(
-    "analyze_current_page",
-    arguments={},
-    context=ctx
+result = await self.tools.execute(
+    tool_name,
+    tool_args,
+    uid=config.uid or "",
+    context=context
 )
 ```
 

+ 1 - 1
agent/tools/builtin/bash.py

@@ -158,7 +158,7 @@ def _kill_process_tree(pid: int) -> None:
         pass
 
 
-@tool(description="执行 bash 命令")
+@tool(description="执行 bash 命令", hidden_params=["context"])
 async def bash_command(
     command: str,
     timeout: Optional[int] = None,

+ 4 - 0
agent/tools/builtin/feishu/chat.py

@@ -132,6 +132,7 @@ def update_unread_count(contact_name: str, increment: int = 1, reset: bool = Fal
 # ==================== 三、@tool 工具 ====================
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "获取飞书联系人列表",
@@ -158,6 +159,7 @@ async def feishu_get_contact_list(context: Optional[ToolContext] = None) -> Tool
     )
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "给飞书联系人发送消息",
@@ -280,6 +282,7 @@ async def feishu_send_message_to_contact(
         return ToolResult(title="发送异常", output=str(e), error=str(e))
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "获取飞书联系人回复",
@@ -395,6 +398,7 @@ def _convert_feishu_msg_to_openai_content(client: FeishuClient, msg: Dict[str, A
     return blocks
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "获取飞书聊天历史记录",

+ 1 - 1
agent/tools/builtin/file/edit.py

@@ -17,7 +17,7 @@ import re
 from agent.tools import tool, ToolResult, ToolContext
 
 
-@tool(description="编辑文件,使用精确字符串替换。支持多种智能匹配策略。")
+@tool(description="编辑文件,使用精确字符串替换。支持多种智能匹配策略。", hidden_params=["context"])
 async def edit_file(
     file_path: str,
     old_string: str,

+ 1 - 1
agent/tools/builtin/file/glob.py

@@ -19,7 +19,7 @@ from agent.tools import tool, ToolResult, ToolContext
 LIMIT = 100  # 最大返回数量(参考 opencode glob.ts:35)
 
 
-@tool(description="使用 glob 模式匹配文件")
+@tool(description="使用 glob 模式匹配文件", hidden_params=["context"])
 async def glob_files(
     pattern: str,
     path: Optional[str] = None,

+ 1 - 1
agent/tools/builtin/file/grep.py

@@ -21,7 +21,7 @@ LIMIT = 100  # 最大返回匹配数(参考 opencode grep.ts:107)
 MAX_LINE_LENGTH = 2000  # 最大行长度(参考 opencode grep.ts:10)
 
 
-@tool(description="在文件内容中搜索模式")
+@tool(description="在文件内容中搜索模式", hidden_params=["context"])
 async def grep_content(
     pattern: str,
     path: Optional[str] = None,

+ 1 - 1
agent/tools/builtin/file/read.py

@@ -27,7 +27,7 @@ MAX_LINE_LENGTH = 2000
 MAX_BYTES = 50 * 1024  # 50KB
 
 
-@tool(description="读取文件内容,支持文本文件、图片、PDF 等多种格式,也支持 HTTP/HTTPS URL")
+@tool(description="读取文件内容,支持文本文件、图片、PDF 等多种格式,也支持 HTTP/HTTPS URL", hidden_params=["context"])
 async def read_file(
     file_path: str,
     offset: int = 0,

+ 1 - 1
agent/tools/builtin/file/write.py

@@ -16,7 +16,7 @@ import difflib
 from agent.tools import tool, ToolResult, ToolContext
 
 
-@tool(description="写入文件内容(创建新文件、覆盖现有文件或追加内容)")
+@tool(description="写入文件内容(创建新文件、覆盖现有文件或追加内容)", hidden_params=["context"])
 async def write_file(
     file_path: str,
     content: str,

+ 1 - 1
agent/tools/builtin/glob_tool.py

@@ -19,7 +19,7 @@ from agent.tools import tool, ToolResult, ToolContext
 LIMIT = 100  # 最大返回数量(参考 opencode glob.ts:35)
 
 
-@tool(description="使用 glob 模式匹配文件")
+@tool(description="使用 glob 模式匹配文件", hidden_params=["context"])
 async def glob_files(
     pattern: str,
     path: Optional[str] = None,

+ 6 - 6
agent/tools/builtin/knowledge.py

@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 KNOWHUB_API = os.getenv("KNOWHUB_API", "http://localhost:8000")
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_search(
     query: str,
     top_k: int = 5,
@@ -95,7 +95,7 @@ async def knowledge_search(
         )
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_save(
     task: str,
     content: str,
@@ -187,7 +187,7 @@ async def knowledge_save(
         )
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_update(
     knowledge_id: str,
     add_helpful_case: Optional[Dict] = None,
@@ -257,7 +257,7 @@ async def knowledge_update(
         )
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_batch_update(
     feedback_list: List[Dict[str, Any]],
     context: Optional[ToolContext] = None,
@@ -307,7 +307,7 @@ async def knowledge_batch_update(
         )
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_list(
     limit: int = 10,
     types: Optional[List[str]] = None,
@@ -369,7 +369,7 @@ async def knowledge_list(
         )
 
 
-@tool()
+@tool(hidden_params=["context"])
 async def knowledge_slim(
     model: str = "google/gemini-2.0-flash-001",
     context: Optional[ToolContext] = None,

+ 3 - 0
agent/tools/builtin/sandbox.py

@@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 300.0
 
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "创建沙盒环境",
@@ -120,6 +121,7 @@ async def sandbox_create_environment(
 
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "执行沙盒命令",
@@ -234,6 +236,7 @@ async def sandbox_run_shell(
 
 
 @tool(
+    hidden_params=["context"],
     display={
         "zh": {
             "name": "重建沙盒端口",

+ 53 - 10
agent/tools/registry.py

@@ -66,7 +66,9 @@ class ToolRegistry:
 		requires_confirmation: bool = False,
 		editable_params: Optional[List[str]] = None,
 		display: Optional[Dict[str, Dict[str, Any]]] = None,
-		url_patterns: Optional[List[str]] = None
+		url_patterns: Optional[List[str]] = None,
+		hidden_params: Optional[List[str]] = None,
+		inject_params: Optional[Dict[str, Any]] = None
 	):
 		"""
 		注册工具
@@ -78,6 +80,8 @@ class ToolRegistry:
 			editable_params: 允许用户编辑的参数列表
 			display: i18n 展示信息 {"zh": {"name": "xx", "params": {...}}, "en": {...}}
 			url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
+			hidden_params: 隐藏参数列表(不生成 schema,LLM 看不到)
+			inject_params: 注入参数规则 {param_name: injector_func}
 		"""
 		func_name = func.__name__
 
@@ -85,7 +89,7 @@ class ToolRegistry:
 		if schema is None:
 			try:
 				from agent.tools.schema import SchemaGenerator
-				schema = SchemaGenerator.generate(func)
+				schema = SchemaGenerator.generate(func, hidden_params=hidden_params or [])
 			except Exception as e:
 				logger.error(f"Failed to generate schema for {func_name}: {e}")
 				raise
@@ -94,6 +98,8 @@ class ToolRegistry:
 			"func": func,
 			"schema": schema,
 			"url_patterns": url_patterns,
+			"hidden_params": hidden_params or [],
+			"inject_params": inject_params or {},
 			"ui_metadata": {
 				"requires_confirmation": requires_confirmation,
 				"editable_params": editable_params or [],
@@ -204,6 +210,7 @@ class ToolRegistry:
 
 		try:
 			func = self._tools[name]["func"]
+			tool_info = self._tools[name]
 
 			# 处理敏感数据占位符
 			if sensitive_data:
@@ -215,14 +222,34 @@ class ToolRegistry:
 			kwargs = {**arguments}
 			sig = inspect.signature(func)
 
-			# 注入 uid(如果函数接受)
-			if "uid" in sig.parameters:
+			# 注入隐藏参数(hidden_params)
+			hidden_params = tool_info.get("hidden_params", [])
+			if "uid" in hidden_params and "uid" in sig.parameters:
 				kwargs["uid"] = uid
-
-			# 注入 context(如果函数接受)
-			if "context" in sig.parameters:
+			if "context" in hidden_params and "context" in sig.parameters:
 				kwargs["context"] = context
 
+			# 注入默认值(inject_params)
+			inject_params = tool_info.get("inject_params", {})
+			for param_name, injector in inject_params.items():
+				if param_name in sig.parameters:
+					# 如果 LLM 已提供值,不覆盖
+					if param_name not in kwargs or kwargs[param_name] is None:
+						if callable(injector):
+							# 检查 injector 的参数数量
+							injector_sig = inspect.signature(injector)
+							if len(injector_sig.parameters) == 1:
+								# lambda ctx: ...
+								kwargs[param_name] = injector(context)
+							elif len(injector_sig.parameters) == 2:
+								# lambda ctx, args: ...
+								kwargs[param_name] = injector(context, kwargs)
+							else:
+								kwargs[param_name] = injector()
+						else:
+							# 直接使用值
+							kwargs[param_name] = injector
+
 			# 执行函数
 			if inspect.iscoroutinefunction(func):
 				result = await func(**kwargs)
@@ -415,7 +442,9 @@ def tool(
 	requires_confirmation: bool = False,
 	editable_params: Optional[List[str]] = None,
 	display: Optional[Dict[str, Dict[str, Any]]] = None,
-	url_patterns: Optional[List[str]] = None
+	url_patterns: Optional[List[str]] = None,
+	hidden_params: Optional[List[str]] = None,
+	inject_params: Optional[Dict[str, Any]] = None
 ):
 	"""
 	工具装饰器 - 自动注册工具并生成 Schema
@@ -427,9 +456,15 @@ def tool(
 		editable_params: 允许用户编辑的参数列表
 		display: i18n 展示信息
 		url_patterns: URL 模式列表(如 ["*.google.com"],None = 无限制)
+		hidden_params: 隐藏参数列表(不生成 schema,LLM 看不到)
+		inject_params: 注入参数规则 {param_name: injector_func}
 
 	Example:
 		@tool(
+			hidden_params=["context", "uid"],
+			inject_params={
+				"owner": lambda ctx: ctx.config.knowledge.get_owner(),
+			},
 			editable_params=["query"],
 			url_patterns=["*.google.com"],
 			display={
@@ -437,7 +472,13 @@ def tool(
 				"en": {"name": "Search Notes", "params": {"query": "Query"}}
 			}
 		)
-		async def search_blocks(query: str, limit: int = 10, uid: str = "") -> str:
+		async def search_blocks(
+			query: str,
+			limit: int = 10,
+			owner: Optional[str] = None,
+			context: Optional[ToolContext] = None,
+			uid: str = ""
+		) -> str:
 			'''搜索用户的笔记块'''
 			...
 	"""
@@ -448,7 +489,9 @@ def tool(
 			requires_confirmation=requires_confirmation,
 			editable_params=editable_params,
 			display=display,
-			url_patterns=url_patterns
+			url_patterns=url_patterns,
+			hidden_params=hidden_params,
+			inject_params=inject_params
 		)
 		return func
 

+ 7 - 4
agent/tools/schema.py

@@ -68,16 +68,19 @@ class SchemaGenerator:
     }
 
     @classmethod
-    def generate(cls, func: callable) -> Dict[str, Any]:
+    def generate(cls, func: callable, hidden_params: Optional[List[str]] = None) -> Dict[str, Any]:
         """
         从函数生成 OpenAI Tool Schema
 
         Args:
             func: 要生成 Schema 的函数
+            hidden_params: 隐藏参数列表(不生成 schema)
 
         Returns:
             OpenAI Tool Schema(JSON 格式)
         """
+        hidden_params = hidden_params or []
+
         # 解析函数签名
         sig = inspect.signature(func)
         func_name = func.__name__
@@ -98,11 +101,11 @@ class SchemaGenerator:
 
         for param_name, param in sig.parameters.items():
             # 跳过特殊参数
-            if param_name in ["self", "cls", "kwargs", "context"]:
+            if param_name in ["self", "cls", "kwargs"]:
                 continue
 
-            # 跳过 uid(由框架自动注入)
-            if param_name == "uid":
+            # 跳过隐藏参数
+            if param_name in hidden_params:
                 continue
 
             # 获取类型注解