|
@@ -73,9 +73,8 @@ class SideBranchContext:
|
|
|
start_head_seq: int # 侧分支起点的 head_seq
|
|
start_head_seq: int # 侧分支起点的 head_seq
|
|
|
start_sequence: int # 侧分支第一条消息的 sequence
|
|
start_sequence: int # 侧分支第一条消息的 sequence
|
|
|
start_history_length: int # 侧分支起点的 history 长度
|
|
start_history_length: int # 侧分支起点的 history 长度
|
|
|
- side_messages: List[Message] # 侧分支产生的消息
|
|
|
|
|
|
|
+ start_iteration: int # 侧分支开始时的 iteration
|
|
|
max_turns: int = 5 # 最大轮次
|
|
max_turns: int = 5 # 最大轮次
|
|
|
- current_turn: int = 0 # 当前轮次
|
|
|
|
|
|
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
|
"""转换为字典(用于持久化和传递给工具)"""
|
|
"""转换为字典(用于持久化和传递给工具)"""
|
|
@@ -84,8 +83,8 @@ class SideBranchContext:
|
|
|
"branch_id": self.branch_id,
|
|
"branch_id": self.branch_id,
|
|
|
"start_head_seq": self.start_head_seq,
|
|
"start_head_seq": self.start_head_seq,
|
|
|
"start_sequence": self.start_sequence,
|
|
"start_sequence": self.start_sequence,
|
|
|
|
|
+ "start_iteration": self.start_iteration,
|
|
|
"max_turns": self.max_turns,
|
|
"max_turns": self.max_turns,
|
|
|
- "current_turn": self.current_turn,
|
|
|
|
|
"is_side_branch": True,
|
|
"is_side_branch": True,
|
|
|
"started_at": datetime.now().isoformat(),
|
|
"started_at": datetime.now().isoformat(),
|
|
|
}
|
|
}
|
|
@@ -107,6 +106,9 @@ class RunConfig:
|
|
|
tools: Optional[List[str]] = None # None = 全部已注册工具
|
|
tools: Optional[List[str]] = None # None = 全部已注册工具
|
|
|
side_branch_max_turns: int = 5 # 侧分支最大轮次(压缩/反思)
|
|
side_branch_max_turns: int = 5 # 侧分支最大轮次(压缩/反思)
|
|
|
|
|
|
|
|
|
|
+ # --- 强制侧分支(用于 API 手动触发)---
|
|
|
|
|
+ force_side_branch: Optional[Literal["compression", "reflection"]] = None
|
|
|
|
|
+
|
|
|
# --- 框架层参数 ---
|
|
# --- 框架层参数 ---
|
|
|
agent_type: str = "default"
|
|
agent_type: str = "default"
|
|
|
uid: Optional[str] = None
|
|
uid: Optional[str] = None
|
|
@@ -310,27 +312,17 @@ class AgentRunner:
|
|
|
side_branch_ctx_for_build: Optional[SideBranchContext] = None
|
|
side_branch_ctx_for_build: Optional[SideBranchContext] = None
|
|
|
if trace.context.get("active_side_branch") and messages:
|
|
if trace.context.get("active_side_branch") and messages:
|
|
|
side_branch_data = trace.context["active_side_branch"]
|
|
side_branch_data = trace.context["active_side_branch"]
|
|
|
- branch_id = side_branch_data["branch_id"]
|
|
|
|
|
|
|
|
|
|
- # 从数据库查询侧分支消息
|
|
|
|
|
- if self.trace_store:
|
|
|
|
|
- all_messages = await self.trace_store.get_trace_messages(trace.trace_id)
|
|
|
|
|
- side_messages = [
|
|
|
|
|
- m for m in all_messages
|
|
|
|
|
- if m.branch_id == branch_id
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
- # 创建侧分支上下文(用于标记用户追加的消息)
|
|
|
|
|
- side_branch_ctx_for_build = SideBranchContext(
|
|
|
|
|
- type=side_branch_data["type"],
|
|
|
|
|
- branch_id=branch_id,
|
|
|
|
|
- start_head_seq=side_branch_data["start_head_seq"],
|
|
|
|
|
- start_sequence=side_branch_data["start_sequence"],
|
|
|
|
|
- start_history_length=0,
|
|
|
|
|
- side_messages=side_messages,
|
|
|
|
|
- max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
|
|
|
|
|
- current_turn=side_branch_data.get("current_turn", 0),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 创建侧分支上下文(用于标记用户追加的消息)
|
|
|
|
|
+ side_branch_ctx_for_build = SideBranchContext(
|
|
|
|
|
+ type=side_branch_data["type"],
|
|
|
|
|
+ branch_id=side_branch_data["branch_id"],
|
|
|
|
|
+ start_head_seq=side_branch_data["start_head_seq"],
|
|
|
|
|
+ start_sequence=side_branch_data["start_sequence"],
|
|
|
|
|
+ start_history_length=0,
|
|
|
|
|
+ start_iteration=side_branch_data.get("start_iteration", 0),
|
|
|
|
|
+ max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# Phase 2: BUILD HISTORY
|
|
# Phase 2: BUILD HISTORY
|
|
|
history, sequence, created_messages, head_seq = await self._build_history(
|
|
history, sequence, created_messages, head_seq = await self._build_history(
|
|
@@ -698,7 +690,6 @@ class AgentRunner:
|
|
|
branch_id=side_branch_ctx.branch_id,
|
|
branch_id=side_branch_ctx.branch_id,
|
|
|
content=msg_dict.get("content"),
|
|
content=msg_dict.get("content"),
|
|
|
)
|
|
)
|
|
|
- side_branch_ctx.side_messages.append(stored_msg)
|
|
|
|
|
logger.info(f"用户在侧分支 {side_branch_ctx.type} 中追加消息")
|
|
logger.info(f"用户在侧分支 {side_branch_ctx.type} 中追加消息")
|
|
|
else:
|
|
else:
|
|
|
stored_msg = Message.from_llm_dict(
|
|
stored_msg = Message.from_llm_dict(
|
|
@@ -949,14 +940,13 @@ class AgentRunner:
|
|
|
start_head_seq=side_branch_data["start_head_seq"],
|
|
start_head_seq=side_branch_data["start_head_seq"],
|
|
|
start_sequence=side_branch_data["start_sequence"],
|
|
start_sequence=side_branch_data["start_sequence"],
|
|
|
start_history_length=0, # 稍后重新计算
|
|
start_history_length=0, # 稍后重新计算
|
|
|
- side_messages=side_messages,
|
|
|
|
|
|
|
+ start_iteration=side_branch_data.get("start_iteration", 0),
|
|
|
max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
|
|
max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
|
|
|
- current_turn=side_branch_data.get("current_turn", 0),
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"恢复未完成的侧分支: {side_branch_ctx.type}, "
|
|
f"恢复未完成的侧分支: {side_branch_ctx.type}, "
|
|
|
- f"已执行 {side_branch_ctx.current_turn}/{side_branch_ctx.max_turns} 轮"
|
|
|
|
|
|
|
+ f"max_turns={side_branch_ctx.max_turns}"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 将侧分支消息追加到 history
|
|
# 将侧分支消息追加到 history
|
|
@@ -999,14 +989,29 @@ class AgentRunner:
|
|
|
# Context 管理(仅主路径)
|
|
# Context 管理(仅主路径)
|
|
|
needs_enter_side_branch = False
|
|
needs_enter_side_branch = False
|
|
|
if not side_branch_ctx:
|
|
if not side_branch_ctx:
|
|
|
- history, head_seq, sequence, needs_enter_side_branch = await self._manage_context_usage(
|
|
|
|
|
- trace_id, history, goal_tree, config, sequence, head_seq
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 检查是否强制进入侧分支(API 手动触发)
|
|
|
|
|
+ if config.force_side_branch:
|
|
|
|
|
+ needs_enter_side_branch = True
|
|
|
|
|
+ logger.info(f"强制进入侧分支: {config.force_side_branch}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 正常的 context 管理逻辑
|
|
|
|
|
+ history, head_seq, sequence, needs_enter_side_branch = await self._manage_context_usage(
|
|
|
|
|
+ trace_id, history, goal_tree, config, sequence, head_seq
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# 进入侧分支
|
|
# 进入侧分支
|
|
|
if needs_enter_side_branch and not side_branch_ctx:
|
|
if needs_enter_side_branch and not side_branch_ctx:
|
|
|
- # 判断侧分支类型:反思 or 压缩
|
|
|
|
|
- branch_type = "reflection" if config.knowledge.enable_extraction else "compression"
|
|
|
|
|
|
|
+ # 判断侧分支类型
|
|
|
|
|
+ if config.force_side_branch:
|
|
|
|
|
+ # API 强制触发
|
|
|
|
|
+ branch_type = config.force_side_branch
|
|
|
|
|
+ elif config.knowledge.enable_extraction:
|
|
|
|
|
+ # 自动触发:反思
|
|
|
|
|
+ branch_type = "reflection"
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 自动触发:压缩
|
|
|
|
|
+ branch_type = "compression"
|
|
|
|
|
+
|
|
|
branch_id = f"{branch_type}_{uuid.uuid4().hex[:8]}"
|
|
branch_id = f"{branch_type}_{uuid.uuid4().hex[:8]}"
|
|
|
|
|
|
|
|
side_branch_ctx = SideBranchContext(
|
|
side_branch_ctx = SideBranchContext(
|
|
@@ -1015,9 +1020,8 @@ class AgentRunner:
|
|
|
start_head_seq=head_seq,
|
|
start_head_seq=head_seq,
|
|
|
start_sequence=sequence,
|
|
start_sequence=sequence,
|
|
|
start_history_length=len(history),
|
|
start_history_length=len(history),
|
|
|
- side_messages=[],
|
|
|
|
|
|
|
+ start_iteration=iteration,
|
|
|
max_turns=config.side_branch_max_turns,
|
|
max_turns=config.side_branch_max_turns,
|
|
|
- current_turn=0,
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# 持久化侧分支状态
|
|
# 持久化侧分支状态
|
|
@@ -1027,8 +1031,8 @@ class AgentRunner:
|
|
|
"branch_id": side_branch_ctx.branch_id,
|
|
"branch_id": side_branch_ctx.branch_id,
|
|
|
"start_head_seq": side_branch_ctx.start_head_seq,
|
|
"start_head_seq": side_branch_ctx.start_head_seq,
|
|
|
"start_sequence": side_branch_ctx.start_sequence,
|
|
"start_sequence": side_branch_ctx.start_sequence,
|
|
|
|
|
+ "start_iteration": side_branch_ctx.start_iteration,
|
|
|
"max_turns": side_branch_ctx.max_turns,
|
|
"max_turns": side_branch_ctx.max_turns,
|
|
|
- "current_turn": 0,
|
|
|
|
|
"started_at": datetime.now().isoformat(),
|
|
"started_at": datetime.now().isoformat(),
|
|
|
}
|
|
}
|
|
|
await self.trace_store.update_trace(
|
|
await self.trace_store.update_trace(
|
|
@@ -1058,7 +1062,6 @@ class AgentRunner:
|
|
|
await self.trace_store.add_message(branch_user_msg)
|
|
await self.trace_store.add_message(branch_user_msg)
|
|
|
|
|
|
|
|
history.append(branch_user_msg.to_llm_dict())
|
|
history.append(branch_user_msg.to_llm_dict())
|
|
|
- side_branch_ctx.side_messages.append(branch_user_msg)
|
|
|
|
|
head_seq = sequence
|
|
head_seq = sequence
|
|
|
sequence += 1
|
|
sequence += 1
|
|
|
|
|
|
|
@@ -1174,9 +1177,7 @@ class AgentRunner:
|
|
|
cache_read_tokens=cache_read_tokens or 0,
|
|
cache_read_tokens=cache_read_tokens or 0,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 如果在侧分支,记录到 side_messages
|
|
|
|
|
- if side_branch_ctx:
|
|
|
|
|
- side_branch_ctx.side_messages.append(assistant_msg)
|
|
|
|
|
|
|
+ # 如果在侧分支,记录到 assistant_msg(已持久化,不需要额外维护)
|
|
|
|
|
|
|
|
yield assistant_msg
|
|
yield assistant_msg
|
|
|
head_seq = sequence
|
|
head_seq = sequence
|
|
@@ -1184,18 +1185,11 @@ class AgentRunner:
|
|
|
|
|
|
|
|
# 检查侧分支是否应该退出
|
|
# 检查侧分支是否应该退出
|
|
|
if side_branch_ctx:
|
|
if side_branch_ctx:
|
|
|
- side_branch_ctx.current_turn += 1
|
|
|
|
|
-
|
|
|
|
|
- # 更新持久化状态
|
|
|
|
|
- if self.trace_store:
|
|
|
|
|
- trace.context["active_side_branch"]["current_turn"] = side_branch_ctx.current_turn
|
|
|
|
|
- await self.trace_store.update_trace(
|
|
|
|
|
- trace_id,
|
|
|
|
|
- context=trace.context
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 计算侧分支已执行的轮次
|
|
|
|
|
+ turns_in_branch = iteration - side_branch_ctx.start_iteration
|
|
|
|
|
|
|
|
# 检查是否达到最大轮次
|
|
# 检查是否达到最大轮次
|
|
|
- if side_branch_ctx.current_turn >= side_branch_ctx.max_turns:
|
|
|
|
|
|
|
+ if turns_in_branch >= side_branch_ctx.max_turns:
|
|
|
logger.warning(
|
|
logger.warning(
|
|
|
f"侧分支 {side_branch_ctx.type} 达到最大轮次 "
|
|
f"侧分支 {side_branch_ctx.type} 达到最大轮次 "
|
|
|
f"{side_branch_ctx.max_turns},强制退出"
|
|
f"{side_branch_ctx.max_turns},强制退出"
|
|
@@ -1225,6 +1219,9 @@ class AgentRunner:
|
|
|
side_branch_ctx.start_head_seq
|
|
side_branch_ctx.start_head_seq
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ # 清除强制侧分支配置
|
|
|
|
|
+ config.force_side_branch = None
|
|
|
|
|
+
|
|
|
side_branch_ctx = None
|
|
side_branch_ctx = None
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
@@ -1247,6 +1244,9 @@ class AgentRunner:
|
|
|
history = [m.to_llm_dict() for m in main_path_messages]
|
|
history = [m.to_llm_dict() for m in main_path_messages]
|
|
|
head_seq = side_branch_ctx.start_head_seq
|
|
head_seq = side_branch_ctx.start_head_seq
|
|
|
|
|
|
|
|
|
|
+ # 清除强制侧分支配置
|
|
|
|
|
+ config.force_side_branch = None
|
|
|
|
|
+
|
|
|
side_branch_ctx = None
|
|
side_branch_ctx = None
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
@@ -1256,16 +1256,23 @@ class AgentRunner:
|
|
|
|
|
|
|
|
# 提取结果
|
|
# 提取结果
|
|
|
if side_branch_ctx.type == "compression":
|
|
if side_branch_ctx.type == "compression":
|
|
|
- # 从侧分支消息中提取 summary
|
|
|
|
|
|
|
+ # 从数据库查询侧分支消息并提取 summary
|
|
|
summary_text = ""
|
|
summary_text = ""
|
|
|
- for msg in side_branch_ctx.side_messages:
|
|
|
|
|
- if msg.role == "assistant" and isinstance(msg.content, dict):
|
|
|
|
|
- text = msg.content.get("text", "")
|
|
|
|
|
- if "[[SUMMARY]]" in text:
|
|
|
|
|
- summary_text = text[text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):].strip()
|
|
|
|
|
- break
|
|
|
|
|
- elif text:
|
|
|
|
|
- summary_text = text
|
|
|
|
|
|
|
+ if self.trace_store:
|
|
|
|
|
+ all_messages = await self.trace_store.get_trace_messages(trace_id)
|
|
|
|
|
+ side_messages = [
|
|
|
|
|
+ m for m in all_messages
|
|
|
|
|
+ if m.branch_id == side_branch_ctx.branch_id
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ for msg in side_messages:
|
|
|
|
|
+ if msg.role == "assistant" and isinstance(msg.content, dict):
|
|
|
|
|
+ text = msg.content.get("text", "")
|
|
|
|
|
+ if "[[SUMMARY]]" in text:
|
|
|
|
|
+ summary_text = text[text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):].strip()
|
|
|
|
|
+ break
|
|
|
|
|
+ elif text:
|
|
|
|
|
+ summary_text = text
|
|
|
|
|
|
|
|
if not summary_text:
|
|
if not summary_text:
|
|
|
logger.warning("侧分支未生成有效 summary,使用默认")
|
|
logger.warning("侧分支未生成有效 summary,使用默认")
|
|
@@ -1317,6 +1324,9 @@ class AgentRunner:
|
|
|
head_sequence=head_seq,
|
|
head_sequence=head_seq,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ # 清除强制侧分支配置(避免影响后续续跑)
|
|
|
|
|
+ config.force_side_branch = None
|
|
|
|
|
+
|
|
|
side_branch_ctx = None
|
|
side_branch_ctx = None
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
@@ -1382,7 +1392,6 @@ class AgentRunner:
|
|
|
"type": side_branch_ctx.type,
|
|
"type": side_branch_ctx.type,
|
|
|
"branch_id": side_branch_ctx.branch_id,
|
|
"branch_id": side_branch_ctx.branch_id,
|
|
|
"is_side_branch": True,
|
|
"is_side_branch": True,
|
|
|
- "current_turn": side_branch_ctx.current_turn,
|
|
|
|
|
"max_turns": side_branch_ctx.max_turns,
|
|
"max_turns": side_branch_ctx.max_turns,
|
|
|
} if side_branch_ctx else None,
|
|
} if side_branch_ctx else None,
|
|
|
},
|
|
},
|
|
@@ -1469,9 +1478,7 @@ class AgentRunner:
|
|
|
print(f"[Runner] 截图已保存: {png_path.name}")
|
|
print(f"[Runner] 截图已保存: {png_path.name}")
|
|
|
break # 只存第一张
|
|
break # 只存第一张
|
|
|
|
|
|
|
|
- # 如果在侧分支,记录到 side_messages
|
|
|
|
|
- if side_branch_ctx:
|
|
|
|
|
- side_branch_ctx.side_messages.append(tool_msg)
|
|
|
|
|
|
|
+ # 如果在侧分支,tool_msg 已持久化(不需要额外维护)
|
|
|
|
|
|
|
|
yield tool_msg
|
|
yield tool_msg
|
|
|
head_seq = sequence
|
|
head_seq = sequence
|