|
|
@@ -267,6 +267,55 @@ class AgentRunner:
|
|
|
# 知识保存跟踪(每个 trace 独立)
|
|
|
self._saved_knowledge_ids: Dict[str, List[str]] = {} # trace_id → [knowledge_ids]
|
|
|
|
|
|
+ # 知识确认等待(每个 trace 独立)
|
|
|
+ self._pending_confirmations: Dict[str, asyncio.Future] = {} # trace_id → Future
|
|
|
+
|
|
|
+ async def _wait_for_confirmation(self, trace_id: str, confirm_type: str, data: Dict[str, Any], timeout: float = 300) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 广播知识确认请求并暂停等待前端响应。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ trace_id: Trace ID
|
|
|
+ confirm_type: "knowledge_injection" | "knowledge_save"
|
|
|
+ data: 确认请求的数据(知识内容等)
|
|
|
+ timeout: 超时秒数,默认 300 秒
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 用户响应 {"action": "confirm"|"reject", "edited_args"?: {...}}
|
|
|
+ """
|
|
|
+ from agent.trace.websocket import broadcast_knowledge_confirm_request
|
|
|
+ await broadcast_knowledge_confirm_request(trace_id, confirm_type, data)
|
|
|
+
|
|
|
+ loop = asyncio.get_event_loop()
|
|
|
+ future = loop.create_future()
|
|
|
+ self._pending_confirmations[trace_id] = future
|
|
|
+
|
|
|
+ try:
|
|
|
+ result = await asyncio.wait_for(future, timeout=timeout)
|
|
|
+ return result
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ logger.warning(f"[Knowledge Confirm] 等待确认超时 ({timeout}s),默认确认: trace={trace_id}, type={confirm_type}")
|
|
|
+ return {"action": "confirm"}
|
|
|
+ finally:
|
|
|
+ self._pending_confirmations.pop(trace_id, None)
|
|
|
+
|
|
|
+ async def resolve_confirmation(self, trace_id: str, response: Dict[str, Any]) -> bool:
|
|
|
+ """
|
|
|
+ 前端调用,resolve 等待中的确认 Future。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ trace_id: Trace ID
|
|
|
+ response: {"action": "confirm"|"reject", "edited_args"?: {...}}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True 如果成功 resolve,False 如果没有等待中的确认
|
|
|
+ """
|
|
|
+ future = self._pending_confirmations.get(trace_id)
|
|
|
+ if future is None or future.done():
|
|
|
+ return False
|
|
|
+ future.set_result(response)
|
|
|
+ return True
|
|
|
+
|
|
|
# ===== 核心公开方法 =====
|
|
|
|
|
|
async def run(
|
|
|
@@ -411,6 +460,10 @@ class AgentRunner:
|
|
|
if cancel_event is None:
|
|
|
return False
|
|
|
cancel_event.set()
|
|
|
+ # 如果有等待中的知识确认,立即 resolve 以解除阻塞
|
|
|
+ future = self._pending_confirmations.get(trace_id)
|
|
|
+ if future and not future.done():
|
|
|
+ future.set_result({"action": "confirm"})
|
|
|
return True
|
|
|
|
|
|
# ===== 单次调用(保留)=====
|
|
|
@@ -561,6 +614,12 @@ class AgentRunner:
|
|
|
completed_at=None,
|
|
|
)
|
|
|
trace_obj.status = "running"
|
|
|
+ # 广播状态变化给前端
|
|
|
+ try:
|
|
|
+ from agent.trace.websocket import broadcast_trace_status_changed
|
|
|
+ await broadcast_trace_status_changed(config.trace_id, "running")
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
|
|
|
return trace_obj, goal_tree, sequence
|
|
|
|
|
|
@@ -862,6 +921,13 @@ class AgentRunner:
|
|
|
_cached_exp_text = ""
|
|
|
|
|
|
for iteration in range(config.max_iterations):
|
|
|
+ # 更新活动时间(表明trace正在活跃运行)
|
|
|
+ if self.trace_store:
|
|
|
+ await self.trace_store.update_trace(
|
|
|
+ trace_id,
|
|
|
+ last_activity_at=datetime.now()
|
|
|
+ )
|
|
|
+
|
|
|
# 检查取消信号
|
|
|
cancel_event = self._cancel_events.get(trace_id)
|
|
|
if cancel_event and cancel_event.is_set():
|
|
|
@@ -873,6 +939,12 @@ class AgentRunner:
|
|
|
head_sequence=head_seq,
|
|
|
completed_at=datetime.now(),
|
|
|
)
|
|
|
+ # 广播状态变化给前端
|
|
|
+ try:
|
|
|
+ from agent.trace.websocket import broadcast_trace_status_changed
|
|
|
+ await broadcast_trace_status_changed(trace_id, "stopped")
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
trace_obj = await self.trace_store.get_trace(trace_id)
|
|
|
if trace_obj:
|
|
|
yield trace_obj
|
|
|
@@ -1028,21 +1100,41 @@ class AgentRunner:
|
|
|
context={"runner": self},
|
|
|
)
|
|
|
if relevant_exps:
|
|
|
- # 保存到 goal 对象
|
|
|
- current_goal.knowledge = relevant_exps
|
|
|
- logger.info(f"[Knowledge Injection] 已将 {len(relevant_exps)} 条知识注入到 goal {current_goal.id}: {current_goal.description[:40]}")
|
|
|
- logger.debug(f"[Knowledge Injection] 注入的知识 IDs: {[exp.get('id') for exp in relevant_exps]}")
|
|
|
- # 持久化保存 goal_tree
|
|
|
- await self.trace_store.update_goal_tree(trace_id, goal_tree)
|
|
|
- self.used_ex_ids = [exp['id'] for exp in relevant_exps]
|
|
|
- parts = [f"[{exp['id']}] {exp['content']}" for exp in relevant_exps]
|
|
|
- _cached_exp_text = "## 参考历史经验\n" + "\n\n".join(parts)
|
|
|
- logger.info(
|
|
|
- "经验检索: goal='%s', 命中 %d 条 %s",
|
|
|
- current_goal.description[:40],
|
|
|
- len(relevant_exps),
|
|
|
- self.used_ex_ids,
|
|
|
- )
|
|
|
+ # 暂停等待用户确认知识注入
|
|
|
+ try:
|
|
|
+ confirm_result = await self._wait_for_confirmation(
|
|
|
+ trace_id, "knowledge_injection", {
|
|
|
+ "goal_id": current_goal.id,
|
|
|
+ "goal_description": current_goal.description,
|
|
|
+ "knowledge_items": relevant_exps,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ if confirm_result.get("action") == "reject":
|
|
|
+ logger.info(f"[Knowledge Injection] 用户拒绝注入知识到 goal {current_goal.id}")
|
|
|
+ relevant_exps = []
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"[Knowledge Injection] 确认流程异常,默认注入: {e}")
|
|
|
+
|
|
|
+ if relevant_exps:
|
|
|
+ # 保存到 goal 对象
|
|
|
+ current_goal.knowledge = relevant_exps
|
|
|
+ logger.info(f"[Knowledge Injection] 已将 {len(relevant_exps)} 条知识注入到 goal {current_goal.id}: {current_goal.description[:40]}")
|
|
|
+ logger.debug(f"[Knowledge Injection] 注入的知识 IDs: {[exp.get('id') for exp in relevant_exps]}")
|
|
|
+ # 持久化保存 goal_tree
|
|
|
+ await self.trace_store.update_goal_tree(trace_id, goal_tree)
|
|
|
+ self.used_ex_ids = [exp['id'] for exp in relevant_exps]
|
|
|
+ parts = [f"[{exp['id']}] {exp['content']}" for exp in relevant_exps]
|
|
|
+ _cached_exp_text = "## 参考历史经验\n" + "\n\n".join(parts)
|
|
|
+ logger.info(
|
|
|
+ "经验检索: goal='%s', 命中 %d 条 %s",
|
|
|
+ current_goal.description[:40],
|
|
|
+ len(relevant_exps),
|
|
|
+ self.used_ex_ids,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ current_goal.knowledge = []
|
|
|
+ await self.trace_store.update_goal_tree(trace_id, goal_tree)
|
|
|
+ _cached_exp_text = ""
|
|
|
else:
|
|
|
current_goal.knowledge = []
|
|
|
logger.info(f"[Knowledge Injection] goal {current_goal.id} 未找到相关知识")
|
|
|
@@ -1122,43 +1214,14 @@ class AgentRunner:
|
|
|
sequence += 1
|
|
|
|
|
|
|
|
|
- # 调用 LLM(同时监听 cancel 信号,stop 时能立即中断)
|
|
|
- cancel_event = self._cancel_events.get(trace_id)
|
|
|
- llm_task = asyncio.create_task(
|
|
|
- self.llm_call(
|
|
|
- messages=llm_messages,
|
|
|
- model=config.model,
|
|
|
- tools=tool_schemas,
|
|
|
- temperature=config.temperature,
|
|
|
- **config.extra_llm_params,
|
|
|
- )
|
|
|
+ # 调用 LLM(等待完成后再检查 cancel 信号,不中断正在进行的调用)
|
|
|
+ result = await self.llm_call(
|
|
|
+ messages=llm_messages,
|
|
|
+ model=config.model,
|
|
|
+ tools=tool_schemas,
|
|
|
+ temperature=config.temperature,
|
|
|
+ **config.extra_llm_params,
|
|
|
)
|
|
|
- if cancel_event:
|
|
|
- cancel_wait = asyncio.create_task(cancel_event.wait())
|
|
|
- done, pending = await asyncio.wait(
|
|
|
- {llm_task, cancel_wait},
|
|
|
- return_when=asyncio.FIRST_COMPLETED,
|
|
|
- )
|
|
|
- for t in pending:
|
|
|
- t.cancel()
|
|
|
- if cancel_wait in done:
|
|
|
- # cancel 先触发,直接停止
|
|
|
- llm_task.cancel()
|
|
|
- logger.info(f"Trace {trace_id} cancelled during LLM call")
|
|
|
- if self.trace_store:
|
|
|
- await self.trace_store.update_trace(
|
|
|
- trace_id,
|
|
|
- status="stopped",
|
|
|
- head_sequence=head_seq,
|
|
|
- completed_at=datetime.now(),
|
|
|
- )
|
|
|
- trace_obj = await self.trace_store.get_trace(trace_id)
|
|
|
- if trace_obj:
|
|
|
- yield trace_obj
|
|
|
- return
|
|
|
- result = llm_task.result()
|
|
|
- else:
|
|
|
- result = await llm_task
|
|
|
|
|
|
response_content = result.get("content", "")
|
|
|
tool_calls = result.get("tool_calls")
|
|
|
@@ -1300,6 +1363,51 @@ class AgentRunner:
|
|
|
elif tool_args is None:
|
|
|
tool_args = {}
|
|
|
|
|
|
+ # save_knowledge 暂停确认:在实际执行前等待用户确认
|
|
|
+ if tool_name == "save_knowledge":
|
|
|
+ try:
|
|
|
+ confirm_result = await self._wait_for_confirmation(
|
|
|
+ trace_id, "knowledge_save", {
|
|
|
+ "tool_args": tool_args,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ if confirm_result.get("action") == "reject":
|
|
|
+ logger.info(f"[Knowledge Save] 用户拒绝保存知识: {tool_args.get('scenario', '')[:40]}")
|
|
|
+ tool_result = {"text": "用户拒绝保存此知识。"}
|
|
|
+ # 跳过实际执行,直接构造 tool result
|
|
|
+ tool_text = tool_result["text"]
|
|
|
+ tool_images = []
|
|
|
+ tool_usage = None
|
|
|
+
|
|
|
+ history.append({
|
|
|
+ "role": "tool",
|
|
|
+ "tool_call_id": tc["id"],
|
|
|
+ "content": tool_text,
|
|
|
+ })
|
|
|
+
|
|
|
+ if self.trace_store:
|
|
|
+ tool_message = Message.create(
|
|
|
+ trace_id=trace_id,
|
|
|
+ role="tool",
|
|
|
+ sequence=sequence,
|
|
|
+ goal_id=current_goal_id,
|
|
|
+ parent_sequence=head_seq if head_seq > 0 else None,
|
|
|
+ content=tool_text,
|
|
|
+ tool_call_id=tc["id"],
|
|
|
+ tool_name=tool_name,
|
|
|
+ )
|
|
|
+ await self.trace_store.add_message(tool_message)
|
|
|
+ yield tool_message
|
|
|
+ head_seq = sequence
|
|
|
+ sequence += 1
|
|
|
+ continue
|
|
|
+ elif confirm_result.get("edited_args"):
|
|
|
+ # 用户编辑了内容
|
|
|
+ tool_args.update(confirm_result["edited_args"])
|
|
|
+ logger.info(f"[Knowledge Save] 用户编辑了知识内容后确认保存")
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"[Knowledge Save] 确认流程异常,默认执行: {e}")
|
|
|
+
|
|
|
tool_result = await self.tools.execute(
|
|
|
tool_name,
|
|
|
tool_args,
|