Parcourir la source

git: fix msg list rebuild after compact

Talegorithm il y a 3 heures
Parent
commit
50835f0e57
2 fichiers modifiés avec 45 ajouts et 133 suppressions
  1. agent/core/runner.py (+44 -132)
  2. examples/research/config.py (+1 -1)

+ 44 - 132
agent/core/runner.py

@@ -906,15 +906,13 @@ class AgentRunner:
         if self.trace_store:
         if self.trace_store:
             await self.trace_store.add_message(summary_msg)
             await self.trace_store.add_message(summary_msg)
 
 
-        # 重建 history
-        system_msg = history[0] if history and history[0].get("role") == "system" else None
-        new_history = [system_msg, summary_msg.to_llm_dict()] if system_msg else [summary_msg.to_llm_dict()]
+        new_history = self._rebuild_history_after_compression(
+            history, summary_msg.to_llm_dict(), label="单次压缩"
+        )
 
 
         new_head_seq = sequence
         new_head_seq = sequence
         sequence += 1
         sequence += 1
 
 
-        logger.info(f"单次压缩完成: {len(history)} → {len(new_history)} 条消息")
-
         return new_history, new_head_seq, sequence
         return new_history, new_head_seq, sequence
 
 
     async def _agent_loop(
     async def _agent_loop(
@@ -1319,19 +1317,13 @@ class AgentRunner:
                         if self.trace_store:
                         if self.trace_store:
                             await self.trace_store.add_message(summary_msg)
                             await self.trace_store.add_message(summary_msg)
 
 
-                        # 重建 history
-                        if self.trace_store:
-                            main_path_messages = await self.trace_store.get_main_path_messages(
-                                trace_id, side_branch_ctx.start_head_seq
-                            )
-                            history = [m.to_llm_dict() for m in main_path_messages]
+                        history = self._rebuild_history_after_compression(
+                            history, summary_msg.to_llm_dict(), label="压缩侧分支"
+                        )
 
 
-                        history.append(summary_msg.to_llm_dict())
                         head_seq = sequence
                         head_seq = sequence
                         sequence += 1
                         sequence += 1
 
 
-                        logger.info(f"压缩侧分支完成,history 长度: {len(history)}")
-
                         # 清除侧分支队列
                         # 清除侧分支队列
                         config.force_side_branch = None
                         config.force_side_branch = None
 
 
@@ -1553,135 +1545,55 @@ class AgentRunner:
             if trace_obj:
             if trace_obj:
                 yield trace_obj
                 yield trace_obj
 
 
-    # ===== Level 2: LLM 压缩 =====
+    # ===== 压缩辅助方法 =====
 
 
-    async def _compress_history(
+    def _rebuild_history_after_compression(
         self,
         self,
-        trace_id: str,
         history: List[Dict],
         history: List[Dict],
-        goal_tree: Optional[GoalTree],
-        config: RunConfig,
-        sequence: int,
-        head_seq: int,
-    ) -> Tuple[List[Dict], int, int]:
+        summary_msg_dict: Dict,
+        label: str = "压缩",
+    ) -> List[Dict]:
         """
         """
-        Level 2 压缩:LLM 总结
+        压缩后重建 history:system prompt + 第一条 user message + summary
 
 
-        Step 1: 压缩总结 — LLM 生成 summary
-        Step 2: 存储 summary 为新消息,parent_sequence 跳到 system msg
-        Step 3: 重建 history
+        Args:
+            history: 压缩前的 history
+            summary_msg_dict: summary 消息的 LLM dict
+            label: 日志标签
 
 
         Returns:
         Returns:
-            (new_history, new_head_seq, next_sequence)
+            新的 history
         """
         """
-        logger.info("Level 2 压缩开始: trace=%s, 当前 history 长度=%d", trace_id, len(history))
-
-        # 找到 system message 的 sequence(主路径第一条消息)
-        system_msg_seq = None
-        system_msg_dict = None
-        if self.trace_store:
-            trace_obj = await self.trace_store.get_trace(trace_id)
-            if trace_obj and trace_obj.head_sequence > 0:
-                main_path = await self.trace_store.get_main_path_messages(
-                    trace_id, trace_obj.head_sequence
-                )
-                for msg in main_path:
-                    if msg.role == "system":
-                        system_msg_seq = msg.sequence
-                        system_msg_dict = msg.to_llm_dict()
-                        break
-
-        # Fallback: 从 history 中找 system message
-        if system_msg_dict is None:
-            for msg_dict in history:
-                if msg_dict.get("role") == "system":
-                    system_msg_dict = msg_dict
-                    break
-
-        if system_msg_dict is None:
-            logger.warning("Level 2 压缩跳过:未找到 system message")
-            return history, head_seq, sequence
-
-        # --- Step 1: 经验提取(reflect)---
-        try:
-            from agent.tools.builtin.knowledge import generate_and_save_reflection
-            await generate_and_save_reflection(
-                trace_id=trace_id,
-                messages=history,
-                llm_call_fn=self.llm_call,
-                model=config.model
-            )
-
-        except Exception as e:
-            logger.error(f"Level 2 经验提取失败: {e}")
-
-        # --- Step 2: 压缩总结 + 经验评估 ---
-        compress_prompt = build_compression_prompt(goal_tree)
-        compress_messages = list(history) + [{"role": "user", "content": compress_prompt}]
-
-        # 应用 Prompt Caching
-        compress_messages = self._add_cache_control(
-            compress_messages,
-            config.model,
-            config.enable_prompt_caching
-        )
-
-        compress_result = await self.llm_call(
-            messages=compress_messages,
-            model=config.model,
-            tools=[],
-            temperature=config.temperature,
-            **config.extra_llm_params,
-        )
-
-        raw_output = compress_result.get("content", "").strip()
-        if not raw_output:
-            logger.warning("Level 2 压缩跳过:LLM 未返回内容")
-            return history, head_seq, sequence
-
-        # 提取 [[SUMMARY]] 块
-        summary_text = raw_output
-        if "[[SUMMARY]]" in raw_output:
-            summary_text = raw_output[raw_output.index("[[SUMMARY]]") + len("[[SUMMARY]]"):].strip()
-
-        if not summary_text:
-            logger.warning("Level 2 压缩跳过:LLM 未返回 summary")
-            return history, head_seq, sequence
-
-        # --- Step 3: 存储 summary 消息 ---
-        summary_with_header = build_summary_header(summary_text)
-
-        summary_msg = Message.create(
-            trace_id=trace_id,
-            role="user",
-            sequence=sequence,
-            goal_id=None,
-            parent_sequence=system_msg_seq,  # 跳到 system msg,跳过所有中间消息
-            content=summary_with_header,
-        )
-
-        if self.trace_store:
-            await self.trace_store.add_message(summary_msg)
-
-        new_head_seq = sequence
-        sequence += 1
-
-        # --- Step 4: 重建 history ---
-        new_history = [system_msg_dict, summary_msg.to_llm_dict()]
+        system_msg = None
+        first_user_msg = None
+        for msg in history:
+            if msg.get("role") == "system" and not system_msg:
+                system_msg = msg
+            elif msg.get("role") == "user" and not first_user_msg:
+                first_user_msg = msg
+            if system_msg and first_user_msg:
+                break
 
 
-        # 更新 trace head_sequence
-        if self.trace_store:
-            await self.trace_store.update_trace(
-                trace_id,
-                head_sequence=new_head_seq,
-            )
+        new_history = []
+        if system_msg:
+            new_history.append(system_msg)
+        if first_user_msg:
+            new_history.append(first_user_msg)
+        new_history.append(summary_msg_dict)
 
 
-        logger.info(
-            "Level 2 压缩完成: 旧 history %d 条 → 新 history %d 条, summary 长度=%d",
-            len(history), len(new_history), len(summary_text),
-        )
+        logger.info(f"{label}完成: {len(history)} → {len(new_history)} 条消息")
+        for idx, msg in enumerate(new_history):
+            role = msg.get("role", "unknown")
+            content = msg.get("content", "")
+            if isinstance(content, str):
+                preview = content
+            elif isinstance(content, list):
+                preview = f"[{len(content)} blocks]"
+            else:
+                preview = str(content)
+            logger.info(f"  {label}后[{idx}] {role}: {preview}")
 
 
-        return new_history, new_head_seq, sequence
+        return new_history
 
 
     async def _run_reflect(
     async def _run_reflect(
         self,
         self,

+ 1 - 1
examples/research/config.py

@@ -35,7 +35,7 @@ RUN_CONFIG = RunConfig(
         owner="",  # 所有者(空则尝试从 git config user.email 获取,再空则用 agent:{agent_id})
         owner="",  # 所有者(空则尝试从 git config user.email 获取,再空则用 agent:{agent_id})
         default_tags={"project": "research", "domain": "ai_agent"},  # 默认 tags(会与工具调用参数合并)
         default_tags={"project": "research", "domain": "ai_agent"},  # 默认 tags(会与工具调用参数合并)
         default_scopes=["org:cybertogether"],  # 默认 scopes
         default_scopes=["org:cybertogether"],  # 默认 scopes
-        default_search_types=["strategy", "tool"],  # 默认搜索类型过滤
+        default_search_types=[],  # 默认搜索类型过滤
         default_search_owner=""  # 默认搜索 owner 过滤(空则不过滤)
         default_search_owner=""  # 默认搜索 owner 过滤(空则不过滤)
     )
     )
 )
 )