소스 검색

feat: side branch mode in runner

Talegorithm 9 시간 전
부모
커밋
b988c96812
8개의 변경된 파일, 786개의 추가 그리고 123개의 삭제
  1. 531 109
      agent/core/runner.py
  2. 16 7
      agent/docs/architecture.md
  3. 149 0
      agent/docs/decisions.md
  4. 14 0
      agent/trace/models.py
  5. 6 2
      agent/trace/store.py
  6. 10 4
      knowhub/embeddings.py
  7. 1 1
      knowhub/server.py
  8. 59 0
      test_embeddings.py

+ 531 - 109
agent/core/runner.py

@@ -65,6 +65,32 @@ class ContextUsage:
     image_count: int = 0
     image_count: int = 0
 
 
 
 
@dataclass
class SideBranchContext:
    """Side-branch context (compression / reflection).

    Tracks the state of a temporary side branch forked off the main
    message path so the runner can resume it, roll it back, or report
    it to tools.
    """
    type: Literal["compression", "reflection"]   # branch kind
    branch_id: str                               # unique id grouping all side-branch messages
    start_head_seq: int          # head_seq at the side branch's fork point
    start_sequence: int          # sequence of the first side-branch message
    start_history_length: int    # length of history at the fork point
    side_messages: List[Message] # messages produced inside the side branch
    max_turns: int = 5           # maximum LLM turns before forced exit
    current_turn: int = 0        # turns executed so far
    # Creation timestamp, captured once.  The original stamped a fresh
    # datetime.now() inside to_dict(), so two serializations of the same
    # branch reported different "started_at" values.
    started_at: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for persistence and for passing to tools."""
        if self.started_at is None:
            # Lazily fix the timestamp on first serialization, then reuse it.
            self.started_at = datetime.now().isoformat()
        return {
            "type": self.type,
            "branch_id": self.branch_id,
            "start_head_seq": self.start_head_seq,
            "start_sequence": self.start_sequence,
            "max_turns": self.max_turns,
            "current_turn": self.current_turn,
            "is_side_branch": True,
            "started_at": self.started_at,
        }
+
 # ===== 运行配置 =====
 # ===== 运行配置 =====
 
 
 @dataclass
 @dataclass
@@ -79,6 +105,7 @@ class RunConfig:
     temperature: float = 0.3
     temperature: float = 0.3
     max_iterations: int = 200
     max_iterations: int = 200
     tools: Optional[List[str]] = None          # None = 全部已注册工具
     tools: Optional[List[str]] = None          # None = 全部已注册工具
+    side_branch_max_turns: int = 5             # 侧分支最大轮次(压缩/反思)
 
 
     # --- 框架层参数 ---
     # --- 框架层参数 ---
     agent_type: str = "default"
     agent_type: str = "default"
@@ -279,9 +306,35 @@ class AgentRunner:
             self._cancel_events[trace.trace_id] = asyncio.Event()
             self._cancel_events[trace.trace_id] = asyncio.Event()
             yield trace
             yield trace
 
 
+            # 检查是否有未完成的侧分支(用于用户追加消息场景)
+            side_branch_ctx_for_build: Optional[SideBranchContext] = None
+            if trace.context.get("active_side_branch") and messages:
+                side_branch_data = trace.context["active_side_branch"]
+                branch_id = side_branch_data["branch_id"]
+
+                # 从数据库查询侧分支消息
+                if self.trace_store:
+                    all_messages = await self.trace_store.get_trace_messages(trace.trace_id)
+                    side_messages = [
+                        m for m in all_messages
+                        if m.branch_id == branch_id
+                    ]
+
+                    # 创建侧分支上下文(用于标记用户追加的消息)
+                    side_branch_ctx_for_build = SideBranchContext(
+                        type=side_branch_data["type"],
+                        branch_id=branch_id,
+                        start_head_seq=side_branch_data["start_head_seq"],
+                        start_sequence=side_branch_data["start_sequence"],
+                        start_history_length=0,
+                        side_messages=side_messages,
+                        max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
+                        current_turn=side_branch_data.get("current_turn", 0),
+                    )
+
             # Phase 2: BUILD HISTORY
             # Phase 2: BUILD HISTORY
             history, sequence, created_messages, head_seq = await self._build_history(
             history, sequence, created_messages, head_seq = await self._build_history(
-                trace.trace_id, messages, goal_tree, config, sequence
+                trace.trace_id, messages, goal_tree, config, sequence, side_branch_ctx_for_build
             )
             )
             # Update trace's head_sequence in memory
             # Update trace's head_sequence in memory
             trace.head_sequence = head_seq
             trace.head_sequence = head_seq
@@ -558,7 +611,8 @@ class AgentRunner:
         goal_tree: Optional[GoalTree],
         goal_tree: Optional[GoalTree],
         config: RunConfig,
         config: RunConfig,
         sequence: int,
         sequence: int,
-    ) -> Tuple[List[Dict], int, List[Message]]:
+        side_branch_ctx: Optional[SideBranchContext] = None,
+    ) -> Tuple[List[Dict], int, List[Message], int]:
         """
         """
         构建完整的 LLM 消息历史
         构建完整的 LLM 消息历史
 
 
@@ -566,6 +620,7 @@ class AgentRunner:
         2. 构建 system prompt(新建时注入 skills)
         2. 构建 system prompt(新建时注入 skills)
         3. 新建时:在第一条 user message 末尾注入当前经验
         3. 新建时:在第一条 user message 末尾注入当前经验
         4. 追加 input messages(设置 parent_sequence 链接到当前 head)
         4. 追加 input messages(设置 parent_sequence 链接到当前 head)
+        5. 如果在侧分支中,追加的消息自动标记为侧分支消息
 
 
         Returns:
         Returns:
             (history, next_sequence, created_messages, head_sequence)
             (history, next_sequence, created_messages, head_sequence)
@@ -631,10 +686,26 @@ class AgentRunner:
             history.append(msg_dict)
             history.append(msg_dict)
 
 
             if self.trace_store:
             if self.trace_store:
-                stored_msg = Message.from_llm_dict(
-                    msg_dict, trace_id=trace_id, sequence=sequence,
-                    goal_id=None, parent_sequence=head_seq,
-                )
+                # 如果在侧分支中,标记为侧分支消息
+                if side_branch_ctx:
+                    stored_msg = Message.create(
+                        trace_id=trace_id,
+                        role=msg_dict["role"],
+                        sequence=sequence,
+                        goal_id=goal_tree.current_id if goal_tree else None,
+                        parent_sequence=head_seq,
+                        branch_type=side_branch_ctx.type,
+                        branch_id=side_branch_ctx.branch_id,
+                        content=msg_dict.get("content"),
+                    )
+                    side_branch_ctx.side_messages.append(stored_msg)
+                    logger.info(f"用户在侧分支 {side_branch_ctx.type} 中追加消息")
+                else:
+                    stored_msg = Message.from_llm_dict(
+                        msg_dict, trace_id=trace_id, sequence=sequence,
+                        goal_id=None, parent_sequence=head_seq,
+                    )
+
                 await self.trace_store.add_message(stored_msg)
                 await self.trace_store.add_message(stored_msg)
                 created_messages.append(stored_msg)
                 created_messages.append(stored_msg)
                 head_seq = sequence
                 head_seq = sequence
@@ -648,6 +719,198 @@ class AgentRunner:
 
 
     # ===== Phase 3: AGENT LOOP =====
     # ===== Phase 3: AGENT LOOP =====
 
 
+    async def _manage_context_usage(
+        self,
+        trace_id: str,
+        history: List[Dict],
+        goal_tree: Optional[GoalTree],
+        config: RunConfig,
+        sequence: int,
+        head_seq: int,
+    ) -> Tuple[List[Dict], int, int, bool]:
+        """
+        管理 context 用量:检查、预警、压缩
+
+        Returns:
+            (updated_history, new_head_seq, next_sequence, needs_enter_compression_branch)
+        """
+        compression_config = CompressionConfig()
+        token_count = estimate_tokens(history)
+        max_tokens = compression_config.get_max_tokens(config.model)
+
+        # 计算使用率
+        progress_pct = (token_count / max_tokens * 100) if max_tokens > 0 else 0
+        msg_count = len(history)
+        img_count = sum(
+            1 for msg in history
+            if isinstance(msg.get("content"), list)
+            for part in msg["content"]
+            if isinstance(part, dict) and part.get("type") in ("image", "image_url")
+        )
+
+        # 更新 context usage 快照
+        self._context_usage[trace_id] = ContextUsage(
+            trace_id=trace_id,
+            message_count=msg_count,
+            token_count=token_count,
+            max_tokens=max_tokens,
+            usage_percent=progress_pct,
+            image_count=img_count,
+        )
+
+        # 阈值警告(30%, 50%, 80%)
+        if trace_id not in self._context_warned:
+            self._context_warned[trace_id] = set()
+
+        for threshold in [30, 50, 80]:
+            if progress_pct >= threshold and threshold not in self._context_warned[trace_id]:
+                self._context_warned[trace_id].add(threshold)
+                logger.warning(
+                    f"Context 使用率达到 {threshold}%: {token_count:,} / {max_tokens:,} tokens ({msg_count} 条消息)"
+                )
+
+        # 检查是否需要压缩(token 或消息数量超限)
+        needs_compression_by_tokens = token_count > max_tokens
+        needs_compression_by_count = (
+            compression_config.max_messages > 0 and
+            msg_count > compression_config.max_messages
+        )
+        needs_compression = needs_compression_by_tokens or needs_compression_by_count
+
+        if not needs_compression:
+            return history, head_seq, sequence, False
+
+        # 知识提取:在任何压缩发生前,用完整 history 做反思(进入反思侧分支)
+        if config.knowledge.enable_extraction:
+            # 返回标志,让主循环进入反思侧分支
+            return history, head_seq, sequence, True
+
+        # Level 1 压缩:GoalTree 过滤
+        if self.trace_store and goal_tree:
+            if head_seq > 0:
+                main_path_msgs = await self.trace_store.get_main_path_messages(
+                    trace_id, head_seq
+                )
+                filtered_msgs = filter_by_goal_status(main_path_msgs, goal_tree)
+                if len(filtered_msgs) < len(main_path_msgs):
+                    logger.info(
+                        "Level 1 压缩: %d -> %d 条消息",
+                        len(main_path_msgs), len(filtered_msgs),
+                    )
+                    history = [msg.to_llm_dict() for msg in filtered_msgs]
+                else:
+                    logger.info(
+                        "Level 1 压缩: 无可过滤消息 (%d 条全部保留)",
+                        len(main_path_msgs),
+                    )
+        elif needs_compression:
+            logger.warning(
+                "消息数 (%d) 或 token 数 (%d) 超过阈值,但无法执行 Level 1 压缩(缺少 store 或 goal_tree)",
+                msg_count, token_count,
+            )
+
+        # Level 2 压缩:检查 Level 1 后是否仍超阈值
+        token_count_after = estimate_tokens(history)
+        msg_count_after = len(history)
+        needs_level2_by_tokens = token_count_after > max_tokens
+        needs_level2_by_count = (
+            compression_config.max_messages > 0 and
+            msg_count_after > compression_config.max_messages
+        )
+        needs_level2 = needs_level2_by_tokens or needs_level2_by_count
+
+        if needs_level2:
+            logger.info(
+                "Level 1 后仍超阈值 (消息数=%d/%d, token=%d/%d),需要进入压缩侧分支",
+                msg_count_after, compression_config.max_messages, token_count_after, max_tokens,
+            )
+            # 返回标志,让主循环进入压缩侧分支
+            return history, head_seq, sequence, True
+
+        # 压缩完成后,输出最终发给模型的消息列表
+        logger.info("Level 1 压缩完成,发送给模型的消息列表:")
+        for idx, msg in enumerate(history):
+            role = msg.get("role", "unknown")
+            content = msg.get("content", "")
+            if isinstance(content, str):
+                preview = content[:100] + ("..." if len(content) > 100 else "")
+            elif isinstance(content, list):
+                preview = f"[{len(content)} blocks]"
+            else:
+                preview = str(content)[:100]
+            logger.info(f"  [{idx}] {role}: {preview}")
+
+        return history, head_seq, sequence, False
+
+    async def _single_turn_compress(
+        self,
+        trace_id: str,
+        history: List[Dict],
+        goal_tree: Optional[GoalTree],
+        config: RunConfig,
+        sequence: int,
+        start_head_seq: int,
+    ) -> Tuple[List[Dict], int, int]:
+        """单次 LLM 调用压缩(fallback 方案)"""
+
+        logger.info("执行单次 LLM 压缩(fallback)")
+
+        # 构建压缩 prompt
+        compress_prompt = build_compression_prompt(goal_tree)
+        compress_messages = list(history) + [
+            {"role": "user", "content": compress_prompt}
+        ]
+
+        # 应用 Prompt Caching
+        compress_messages = self._add_cache_control(
+            compress_messages, config.model, config.enable_prompt_caching
+        )
+
+        # 单次 LLM 调用(无工具)
+        result = await self.llm_call(
+            messages=compress_messages,
+            model=config.model,
+            tools=[],  # 不提供工具
+            temperature=config.temperature,
+            **config.extra_llm_params,
+        )
+
+        summary_text = result.get("content", "").strip()
+
+        # 提取 [[SUMMARY]] 块
+        if "[[SUMMARY]]" in summary_text:
+            summary_text = summary_text[
+                summary_text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):
+            ].strip()
+
+        if not summary_text:
+            logger.warning("单次压缩未返回有效内容,跳过压缩")
+            return history, start_head_seq, sequence
+
+        # 创建 summary 消息
+        summary_msg = Message.create(
+            trace_id=trace_id,
+            role="user",
+            sequence=sequence,
+            parent_sequence=start_head_seq,
+            branch_type=None,  # 主路径
+            content=f"[压缩总结 - Fallback]\n{summary_text}",
+        )
+
+        if self.trace_store:
+            await self.trace_store.add_message(summary_msg)
+
+        # 重建 history
+        system_msg = history[0] if history and history[0].get("role") == "system" else None
+        new_history = [system_msg, summary_msg.to_llm_dict()] if system_msg else [summary_msg.to_llm_dict()]
+
+        new_head_seq = sequence
+        sequence += 1
+
+        logger.info(f"单次压缩完成: {len(history)} → {len(new_history)} 条消息")
+
+        return new_history, new_head_seq, sequence
+
     async def _agent_loop(
     async def _agent_loop(
         self,
         self,
         trace: Trace,
         trace: Trace,
@@ -663,6 +926,46 @@ class AgentRunner:
         # 当前主路径头节点的 sequence(用于设置 parent_sequence)
         # 当前主路径头节点的 sequence(用于设置 parent_sequence)
         head_seq = trace.head_sequence
         head_seq = trace.head_sequence
 
 
+        # 侧分支状态(None = 主路径)
+        side_branch_ctx: Optional[SideBranchContext] = None
+
+        # 检查是否有未完成的侧分支需要恢复
+        if trace.context.get("active_side_branch"):
+            side_branch_data = trace.context["active_side_branch"]
+            branch_id = side_branch_data["branch_id"]
+
+            # 从数据库查询侧分支消息
+            if self.trace_store:
+                all_messages = await self.trace_store.get_trace_messages(trace_id)
+                side_messages = [
+                    m for m in all_messages
+                    if m.branch_id == branch_id
+                ]
+
+                # 恢复侧分支上下文
+                side_branch_ctx = SideBranchContext(
+                    type=side_branch_data["type"],
+                    branch_id=branch_id,
+                    start_head_seq=side_branch_data["start_head_seq"],
+                    start_sequence=side_branch_data["start_sequence"],
+                    start_history_length=0,  # 稍后重新计算
+                    side_messages=side_messages,
+                    max_turns=side_branch_data.get("max_turns", config.side_branch_max_turns),
+                    current_turn=side_branch_data.get("current_turn", 0),
+                )
+
+                logger.info(
+                    f"恢复未完成的侧分支: {side_branch_ctx.type}, "
+                    f"已执行 {side_branch_ctx.current_turn}/{side_branch_ctx.max_turns} 轮"
+                )
+
+                # 将侧分支消息追加到 history
+                for m in side_messages:
+                    history.append(m.to_llm_dict())
+
+                # 重新计算 start_history_length
+                side_branch_ctx.start_history_length = len(history) - len(side_messages)
+
         for iteration in range(config.max_iterations):
         for iteration in range(config.max_iterations):
             # 更新活动时间(表明trace正在活跃运行)
             # 更新活动时间(表明trace正在活跃运行)
             if self.trace_store:
             if self.trace_store:
@@ -693,114 +996,74 @@ class AgentRunner:
                         yield trace_obj
                         yield trace_obj
                 return
                 return
 
 
-            # Level 1 压缩:GoalTree 过滤(当消息超过阈值时触发)
-            compression_config = CompressionConfig()
-            token_count = estimate_tokens(history)
-            max_tokens = compression_config.get_max_tokens(config.model)
-
-            # 计算使用率
-            progress_pct = (token_count / max_tokens * 100) if max_tokens > 0 else 0
-            msg_count = len(history)
-            img_count = sum(
-                1 for msg in history
-                if isinstance(msg.get("content"), list)
-                for part in msg["content"]
-                if isinstance(part, dict) and part.get("type") in ("image", "image_url")
-            )
-
-            # 更新 context usage 快照
-            self._context_usage[trace_id] = ContextUsage(
-                trace_id=trace_id,
-                message_count=msg_count,
-                token_count=token_count,
-                max_tokens=max_tokens,
-                usage_percent=progress_pct,
-                image_count=img_count,
-            )
+            # Context 管理(仅主路径)
+            needs_enter_side_branch = False
+            if not side_branch_ctx:
+                history, head_seq, sequence, needs_enter_side_branch = await self._manage_context_usage(
+                    trace_id, history, goal_tree, config, sequence, head_seq
+                )
 
 
-            # 阈值警告(30%, 50%, 80%)
-            if trace_id not in self._context_warned:
-                self._context_warned[trace_id] = set()
+            # 进入侧分支
+            if needs_enter_side_branch and not side_branch_ctx:
+                # 判断侧分支类型:反思 or 压缩
+                branch_type = "reflection" if config.knowledge.enable_extraction else "compression"
+                branch_id = f"{branch_type}_{uuid.uuid4().hex[:8]}"
+
+                side_branch_ctx = SideBranchContext(
+                    type=branch_type,
+                    branch_id=branch_id,
+                    start_head_seq=head_seq,
+                    start_sequence=sequence,
+                    start_history_length=len(history),
+                    side_messages=[],
+                    max_turns=config.side_branch_max_turns,
+                    current_turn=0,
+                )
 
 
-            for threshold in [30, 50, 80]:
-                if progress_pct >= threshold and threshold not in self._context_warned[trace_id]:
-                    self._context_warned[trace_id].add(threshold)
-                    logger.warning(
-                        f"Context 使用率达到 {threshold}%: {token_count:,} / {max_tokens:,} tokens ({msg_count} 条消息)"
+                # 持久化侧分支状态
+                if self.trace_store:
+                    trace.context["active_side_branch"] = {
+                        "type": side_branch_ctx.type,
+                        "branch_id": side_branch_ctx.branch_id,
+                        "start_head_seq": side_branch_ctx.start_head_seq,
+                        "start_sequence": side_branch_ctx.start_sequence,
+                        "max_turns": side_branch_ctx.max_turns,
+                        "current_turn": 0,
+                        "started_at": datetime.now().isoformat(),
+                    }
+                    await self.trace_store.update_trace(
+                        trace_id,
+                        context=trace.context
                     )
                     )
 
 
-            # 检查是否需要压缩(token 或消息数量超限)
-            needs_compression_by_tokens = token_count > max_tokens
-            needs_compression_by_count = (
-                compression_config.max_messages > 0 and
-                msg_count > compression_config.max_messages
-            )
-            needs_compression = needs_compression_by_tokens or needs_compression_by_count
-
-            # 知识提取:在任何压缩发生前,用完整 history 做反思
-            if needs_compression and config.knowledge.enable_extraction:
-                await self._run_reflect(
-                    trace_id, history, config,
-                    reflect_prompt=config.knowledge.get_reflect_prompt(),
-                    source_name="compression_reflection",
-                )
+                # 追加侧分支 prompt
+                if branch_type == "reflection":
+                    prompt = config.knowledge.get_reflect_prompt()
+                else:  # compression
+                    from agent.trace.compaction import build_compression_prompt
+                    prompt = build_compression_prompt(goal_tree)
 
 
-            # Level 1 压缩:GoalTree 过滤
-            if needs_compression and self.trace_store and goal_tree:
-                if head_seq > 0:
-                    main_path_msgs = await self.trace_store.get_main_path_messages(
-                        trace_id, head_seq
-                    )
-                    filtered_msgs = filter_by_goal_status(main_path_msgs, goal_tree)
-                    if len(filtered_msgs) < len(main_path_msgs):
-                        logger.info(
-                            "Level 1 压缩: %d -> %d 条消息",
-                            len(main_path_msgs), len(filtered_msgs),
-                        )
-                        history = [msg.to_llm_dict() for msg in filtered_msgs]
-                    else:
-                        logger.info(
-                            "Level 1 压缩: 无可过滤消息 (%d 条全部保留)",
-                            len(main_path_msgs),
-                        )
-            elif needs_compression:
-                logger.warning(
-                    "消息数 (%d) 或 token 数 (%d) 超过阈值,但无法执行 Level 1 压缩(缺少 store 或 goal_tree)",
-                    msg_count, token_count,
+                branch_user_msg = Message.create(
+                    trace_id=trace_id,
+                    role="user",
+                    sequence=sequence,
+                    parent_sequence=head_seq,
+                    goal_id=goal_tree.current_id if goal_tree else None,
+                    branch_type=branch_type,
+                    branch_id=branch_id,
+                    content=prompt,
                 )
                 )
 
 
-            # Level 2 压缩:LLM 总结(Level 1 后仍超阈值时触发)
-            token_count_after = estimate_tokens(history)
-            msg_count_after = len(history)
-            needs_level2_by_tokens = token_count_after > max_tokens
-            needs_level2_by_count = (
-                compression_config.max_messages > 0 and
-                msg_count_after > compression_config.max_messages
-            )
-            needs_level2 = needs_level2_by_tokens or needs_level2_by_count
+                if self.trace_store:
+                    await self.trace_store.add_message(branch_user_msg)
 
 
-            if needs_level2:
-                logger.info(
-                    "Level 1 后仍超阈值 (消息数=%d/%d, token=%d/%d),触发 Level 2 压缩",
-                    msg_count_after, compression_config.max_messages, token_count_after, max_tokens,
-                )
-                history, head_seq, sequence = await self._compress_history(
-                    trace_id, history, goal_tree, config, sequence, head_seq,
-                )
+                history.append(branch_user_msg.to_llm_dict())
+                side_branch_ctx.side_messages.append(branch_user_msg)
+                head_seq = sequence
+                sequence += 1
 
 
-            # 压缩完成后,输出最终发给模型的消息列表
-            if needs_compression:
-                logger.info("压缩完成,发送给模型的消息列表:")
-                for idx, msg in enumerate(history):
-                    role = msg.get("role", "unknown")
-                    content = msg.get("content", "")
-                    if isinstance(content, str):
-                        preview = content[:100] + ("..." if len(content) > 100 else "")
-                    elif isinstance(content, list):
-                        preview = f"[{len(content)} blocks]"
-                    else:
-                        preview = str(content)[:100]
-                    logger.info(f"  [{idx}] {role}: {preview}")
+                logger.info(f"进入侧分支: {branch_type}, branch_id={branch_id}")
+                continue  # 跳过本轮,下一轮开始侧分支
 
 
             # 构建 LLM messages(注入上下文)
             # 构建 LLM messages(注入上下文)
             llm_messages = list(history)
             llm_messages = list(history)
@@ -813,7 +1076,8 @@ class AgentRunner:
             )
             )
 
 
             # 周期性注入 GoalTree + Collaborators(动态内容追加在缓存点之后)
             # 周期性注入 GoalTree + Collaborators(动态内容追加在缓存点之后)
-            if iteration % CONTEXT_INJECTION_INTERVAL == 0:
+            # 仅在主路径执行
+            if not side_branch_ctx and iteration % CONTEXT_INJECTION_INTERVAL == 0:
                 context_injection = self._build_context_injection(trace, goal_tree)
                 context_injection = self._build_context_injection(trace, goal_tree)
                 if context_injection:
                 if context_injection:
                     system_msg = {"role": "system", "content": context_injection}
                     system_msg = {"role": "system", "content": context_injection}
@@ -854,8 +1118,8 @@ class AgentRunner:
             cache_creation_tokens = result.get("cache_creation_tokens")
             cache_creation_tokens = result.get("cache_creation_tokens")
             cache_read_tokens = result.get("cache_read_tokens")
             cache_read_tokens = result.get("cache_read_tokens")
 
 
-            # 按需自动创建 root goal
-            if goal_tree and not goal_tree.goals and tool_calls:
+            # 按需自动创建 root goal(仅主路径)
+            if not side_branch_ctx and goal_tree and not goal_tree.goals and tool_calls:
                 has_goal_call = any(
                 has_goal_call = any(
                     tc.get("function", {}).get("name") == "goal"
                     tc.get("function", {}).get("name") == "goal"
                     for tc in tool_calls
                     for tc in tool_calls
@@ -886,6 +1150,8 @@ class AgentRunner:
                 sequence=sequence,
                 sequence=sequence,
                 goal_id=current_goal_id,
                 goal_id=current_goal_id,
                 parent_sequence=head_seq if head_seq > 0 else None,
                 parent_sequence=head_seq if head_seq > 0 else None,
+                branch_type=side_branch_ctx.type if side_branch_ctx else None,
+                branch_id=side_branch_ctx.branch_id if side_branch_ctx else None,
                 content={"text": response_content, "tool_calls": tool_calls},
                 content={"text": response_content, "tool_calls": tool_calls},
                 prompt_tokens=prompt_tokens,
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 completion_tokens=completion_tokens,
@@ -900,7 +1166,7 @@ class AgentRunner:
                 # 记录模型使用
                 # 记录模型使用
                 await self.trace_store.record_model_usage(
                 await self.trace_store.record_model_usage(
                     trace_id=trace_id,
                     trace_id=trace_id,
-                    sequence=sequence - 1,  # assistant_msg的sequence
+                    sequence=sequence,
                     role="assistant",
                     role="assistant",
                     model=config.model,
                     model=config.model,
                     prompt_tokens=prompt_tokens,
                     prompt_tokens=prompt_tokens,
@@ -908,10 +1174,152 @@ class AgentRunner:
                     cache_read_tokens=cache_read_tokens or 0,
                     cache_read_tokens=cache_read_tokens or 0,
                 )
                 )
 
 
+            # 如果在侧分支,记录到 side_messages
+            if side_branch_ctx:
+                side_branch_ctx.side_messages.append(assistant_msg)
+
             yield assistant_msg
             yield assistant_msg
             head_seq = sequence
             head_seq = sequence
             sequence += 1
             sequence += 1
 
 
+            # 检查侧分支是否应该退出
+            if side_branch_ctx:
+                side_branch_ctx.current_turn += 1
+
+                # 更新持久化状态
+                if self.trace_store:
+                    trace.context["active_side_branch"]["current_turn"] = side_branch_ctx.current_turn
+                    await self.trace_store.update_trace(
+                        trace_id,
+                        context=trace.context
+                    )
+
+                # 检查是否达到最大轮次
+                if side_branch_ctx.current_turn >= side_branch_ctx.max_turns:
+                    logger.warning(
+                        f"侧分支 {side_branch_ctx.type} 达到最大轮次 "
+                        f"{side_branch_ctx.max_turns},强制退出"
+                    )
+
+                    if side_branch_ctx.type == "compression":
+                        # 压缩侧分支:fallback 到单次 LLM 调用
+                        logger.info("Fallback 到单次 LLM 压缩")
+
+                        # 清除侧分支状态
+                        trace.context.pop("active_side_branch", None)
+                        if self.trace_store:
+                            await self.trace_store.update_trace(
+                                trace_id, context=trace.context
+                            )
+
+                        # 恢复到侧分支开始前的 history
+                        if self.trace_store:
+                            main_path_messages = await self.trace_store.get_main_path_messages(
+                                trace_id, side_branch_ctx.start_head_seq
+                            )
+                            history = [m.to_llm_dict() for m in main_path_messages]
+
+                        # 执行单次 LLM 压缩
+                        history, head_seq, sequence = await self._single_turn_compress(
+                            trace_id, history, goal_tree, config, sequence,
+                            side_branch_ctx.start_head_seq
+                        )
+
+                        side_branch_ctx = None
+                        continue
+
+                    elif side_branch_ctx.type == "reflection":
+                        # 反思侧分支:直接退出,不管结果
+                        logger.info("反思侧分支超时,直接退出")
+
+                        # 清除侧分支状态
+                        trace.context.pop("active_side_branch", None)
+                        if self.trace_store:
+                            await self.trace_store.update_trace(
+                                trace_id, context=trace.context
+                            )
+
+                        # 恢复到侧分支开始前的 history
+                        if self.trace_store:
+                            main_path_messages = await self.trace_store.get_main_path_messages(
+                                trace_id, side_branch_ctx.start_head_seq
+                            )
+                            history = [m.to_llm_dict() for m in main_path_messages]
+                            head_seq = side_branch_ctx.start_head_seq
+
+                        side_branch_ctx = None
+                        continue
+
+                # 检查是否无工具调用(侧分支完成)
+                if not tool_calls:
+                    logger.info(f"侧分支 {side_branch_ctx.type} 完成(无工具调用)")
+
+                    # 提取结果
+                    if side_branch_ctx.type == "compression":
+                        # 从侧分支消息中提取 summary
+                        summary_text = ""
+                        for msg in side_branch_ctx.side_messages:
+                            if msg.role == "assistant" and isinstance(msg.content, dict):
+                                text = msg.content.get("text", "")
+                                if "[[SUMMARY]]" in text:
+                                    summary_text = text[text.index("[[SUMMARY]]") + len("[[SUMMARY]]"):].strip()
+                                    break
+                                elif text:
+                                    summary_text = text
+
+                        if not summary_text:
+                            logger.warning("侧分支未生成有效 summary,使用默认")
+                            summary_text = "压缩完成"
+
+                        # 创建主路径的 summary 消息
+                        summary_msg = Message.create(
+                            trace_id=trace_id,
+                            role="user",
+                            sequence=sequence,
+                            parent_sequence=side_branch_ctx.start_head_seq,
+                            branch_type=None,  # 回到主路径
+                            content=f"[压缩总结]\n{summary_text}",
+                        )
+
+                        if self.trace_store:
+                            await self.trace_store.add_message(summary_msg)
+
+                        # 重建 history
+                        if self.trace_store:
+                            main_path_messages = await self.trace_store.get_main_path_messages(
+                                trace_id, side_branch_ctx.start_head_seq
+                            )
+                            history = [m.to_llm_dict() for m in main_path_messages]
+
+                        history.append(summary_msg.to_llm_dict())
+                        head_seq = sequence
+                        sequence += 1
+
+                        logger.info(f"压缩侧分支完成,history 长度: {len(history)}")
+
+                    elif side_branch_ctx.type == "reflection":
+                        # 反思侧分支:直接恢复主路径
+                        logger.info("反思侧分支完成")
+
+                        if self.trace_store:
+                            main_path_messages = await self.trace_store.get_main_path_messages(
+                                trace_id, side_branch_ctx.start_head_seq
+                            )
+                            history = [m.to_llm_dict() for m in main_path_messages]
+                            head_seq = side_branch_ctx.start_head_seq
+
+                    # 清除侧分支状态
+                    trace.context.pop("active_side_branch", None)
+                    if self.trace_store:
+                        await self.trace_store.update_trace(
+                            trace_id,
+                            context=trace.context,
+                            head_sequence=head_seq,
+                        )
+
+                    side_branch_ctx = None
+                    continue
+
             # 处理工具调用
             # 处理工具调用
             # 截断兜底:finish_reason == "length" 说明响应被 max_tokens 截断,
             # 截断兜底:finish_reason == "length" 说明响应被 max_tokens 截断,
             # tool call 参数很可能不完整,不应执行,改为提示模型分批操作
             # tool call 参数很可能不完整,不应执行,改为提示模型分批操作
@@ -969,6 +1377,14 @@ class AgentRunner:
                             "runner": self,
                             "runner": self,
                             "goal_tree": goal_tree,
                             "goal_tree": goal_tree,
                             "knowledge_config": config.knowledge,
                             "knowledge_config": config.knowledge,
+                            # 新增:侧分支信息
+                            "side_branch": {
+                                "type": side_branch_ctx.type,
+                                "branch_id": side_branch_ctx.branch_id,
+                                "is_side_branch": True,
+                                "current_turn": side_branch_ctx.current_turn,
+                                "max_turns": side_branch_ctx.max_turns,
+                            } if side_branch_ctx else None,
                         },
                         },
                     )
                     )
 
 
@@ -1023,6 +1439,8 @@ class AgentRunner:
                         goal_id=current_goal_id,
                         goal_id=current_goal_id,
                         parent_sequence=head_seq,
                         parent_sequence=head_seq,
                         tool_call_id=tc["id"],
                         tool_call_id=tc["id"],
+                        branch_type=side_branch_ctx.type if side_branch_ctx else None,
+                        branch_id=side_branch_ctx.branch_id if side_branch_ctx else None,
                         # 存储完整内容:有图片时保留 list(含 image_url),纯文本时存字符串
                         # 存储完整内容:有图片时保留 list(含 image_url),纯文本时存字符串
                         content={"tool_name": tool_name, "result": tool_content_for_llm},
                         content={"tool_name": tool_name, "result": tool_content_for_llm},
                     )
                     )
@@ -1051,6 +1469,10 @@ class AgentRunner:
                                     print(f"[Runner] 截图已保存: {png_path.name}")
                                     print(f"[Runner] 截图已保存: {png_path.name}")
                                     break  # 只存第一张
                                     break  # 只存第一张
 
 
+                    # 如果在侧分支,记录到 side_messages
+                    if side_branch_ctx:
+                        side_branch_ctx.side_messages.append(tool_msg)
+
                     yield tool_msg
                     yield tool_msg
                     head_seq = sequence
                     head_seq = sequence
                     sequence += 1
                     sequence += 1

+ 16 - 7
agent/docs/architecture.md

@@ -523,15 +523,23 @@ class Message:
 正常对话:1 → 2 → 3 → 4 → 5       (每条的 parent 指向前一条)
 正常对话:1 → 2 → 3 → 4 → 5       (每条的 parent 指向前一条)
 Rewind 到 3:3 → 6(parent=3) → 7   (新主路径,4-5 自动脱离)
 Rewind 到 3:3 → 6(parent=3) → 7   (新主路径,4-5 自动脱离)
 压缩 1-3:   8(summary, parent=None) → 6 → 7  (summary 跳过被压缩的消息)
 压缩 1-3:   8(summary, parent=None) → 6 → 7  (summary 跳过被压缩的消息)
-反思分支:   5 → 9(reflect, parent=5) → 10     (侧枝,不在主路径上)
+侧分支:     5 → 6(branch_type="compression", parent=5) → 7(parent=6)
+            5 → 8(summary, parent=5, 主路径)
+            (侧分支消息 6-7 通过 parent_sequence 自然脱离主路径)
 ```
 ```
 
 
-`build_llm_messages` = 从 head 沿 parent_sequence 链回溯到 root,反转后返回。
+`build_llm_messages` = 从 `trace.head_sequence` 沿 parent_sequence 链回溯到 root,反转后返回。
+
+**关键设计**:只要 `trace.head_sequence` 管理正确(始终指向主路径),`get_main_path_messages()` 自然返回主路径消息,侧分支消息通过 parent_sequence 链自动被跳过,无需额外过滤。
 
 
 Message 提供格式转换方法:
 Message 提供格式转换方法:
 - `to_llm_dict()` → OpenAI 格式 Dict(用于 LLM 调用)
 - `to_llm_dict()` → OpenAI 格式 Dict(用于 LLM 调用)
 - `from_llm_dict(d, trace_id, sequence, goal_id)` → 从 OpenAI 格式创建 Message
 - `from_llm_dict(d, trace_id, sequence, goal_id)` → 从 OpenAI 格式创建 Message
 
 
+**侧分支字段**:
+- `branch_type`: "compression" | "reflection" | None(主路径)
+- `branch_id`: 同一侧分支的消息共享 branch_id
+
 **实现**:`agent/trace/models.py`
 **实现**:`agent/trace/models.py`
 
 
 ---
 ---
@@ -1246,8 +1254,10 @@ async def get_experience(
 触发条件:Level 1 之后 token 数仍超过阈值(默认 `max_tokens × 0.8`)。
 触发条件:Level 1 之后 token 数仍超过阈值(默认 `max_tokens × 0.8`)。
 
 
 流程:
 流程:
-1. **经验提取**:先在消息列表末尾追加反思 prompt → 主模型回复 → 追加到 `./.cache/experiences.md`。反思消息为侧枝(parent_sequence 分叉,不在主路径上)
-2. **压缩**:在消息列表末尾追加压缩 prompt(含 GoalTree 完整视图) → 主模型回复 → summary 存为新消息,其 `parent_sequence` 跳过被压缩的范围
+1. **经验提取**:在消息列表末尾追加反思 prompt,进入侧分支 agent 模式(最多 5 轮),LLM 可调用工具(如 knowledge_search, knowledge_save)进行多轮推理。反思消息标记为 `branch_type="reflection"`,不在主路径上
+2. **压缩**:在消息列表末尾追加压缩 prompt(含 GoalTree 完整视图),进入侧分支 agent 模式(最多 5 轮),LLM 可调用工具(如 goal_status)辅助压缩。压缩消息标记为 `branch_type="compression"`,完成后创建 summary 消息,其 `parent_sequence` 跳过被压缩的范围
+
+**侧分支模式**:压缩和反思在同一 agent loop 中通过状态机实现,复用主路径的缓存和工具配置,支持多轮推理。
 
 
 ### GoalTree 双视图
 ### GoalTree 双视图
 
 
@@ -1259,12 +1269,11 @@ async def get_experience(
 
 
 - 原始消息永远保留在 `messages/`
 - 原始消息永远保留在 `messages/`
 - 压缩 summary 作为普通 Message 存储
 - 压缩 summary 作为普通 Message 存储
+- 侧分支消息通过 `branch_type` 和 `branch_id` 标记,查询主路径时自动过滤
 - 通过 `parent_sequence` 树结构实现跳过,无需 compression events 或 skip list
 - 通过 `parent_sequence` 树结构实现跳过,无需 compression events 或 skip list
 - Rewind 到压缩区域内时,summary 脱离主路径,原始消息自动恢复
 - Rewind 到压缩区域内时,summary 脱离主路径,原始消息自动恢复
 
 
-**实现**:`agent/trace/compaction.py`, `agent/trace/goal_models.py`
-
-**详细文档**:[Context 管理](./context-management.md)
+**实现**:`agent/core/runner.py:_agent_loop`, `agent/trace/compaction.py`, `agent/trace/goal_models.py`
 
 
 ---
 ---
 
 

+ 149 - 0
agent/docs/decisions.md

@@ -1156,4 +1156,153 @@ Rewind 事件 payload 中增加 `head_sequence` 字段,便于前端感知分
 
 
 **实现**:`agent/trace/run_api.py`, `agent/core/runner.py`, `agent/trace/api.py`
 **实现**:`agent/trace/run_api.py`, `agent/core/runner.py`, `agent/trace/api.py`
 
 
+---
+
+## Decision 24: 侧分支多轮 Agent 模式
+
+**日期**: 2026-03-09
+
+### 问题
+
+原有的压缩和反思使用单轮 LLM 调用,但这些任务可能需要多轮推理和工具调用才能做好:
+- **压缩**:可能需要查询 goal_tree 状态、分步总结
+- **反思**:可能需要先分析失败原因、再提取经验,或检查知识库避免重复
+
+单轮调用限制了 LLM 的推理能力,且改变 system prompt 或工具清单会导致缓存失效。
+
+### 决策
+
+**选择:侧分支在同一 agent loop 中以状态机模式运行**
+
+#### 24a. 核心设计
+
+侧分支不是递归调用 `_agent_loop`,而是在同一个循环中通过状态切换实现:
+
+```python
+# 主循环维护侧分支状态
+side_branch_ctx: Optional[SideBranchContext] = None
+
+for iteration in range(max_iterations):
+    # 进入侧分支:追加 prompt,设置状态
+    if needs_compression and not side_branch_ctx:
+        side_branch_ctx = SideBranchContext(...)
+        history.append({"role": "user", "content": compress_prompt})
+        continue
+
+    # 侧分支中:正常执行 LLM 调用和工具执行
+    result = await self.llm_call(history, tools=..., model=...)
+
+    # 退出侧分支:提取结果,回到起点
+    if side_branch_ctx and not tool_calls:
+        summary = extract_summary(side_branch_ctx.side_messages)
+        history = history[:side_branch_ctx.start_history_length]
+        # 创建主路径 summary 消息
+        side_branch_ctx = None
+        continue
+```
+
+**优势**:
+1. **缓存友好**:复用主路径的所有缓存,只有追加的 prompt 是新内容
+2. **工具自然可用**:不需要单独配置工具清单,agent 自由选择需要的工具
+3. **实现简洁**:不需要递归调用,状态管理清晰
+
+#### 24b. 侧分支上下文结构
+
+```python
+@dataclass
+class SideBranchContext:
+    type: Literal["compression", "reflection"]  # 与 runner.py 中 SideBranchContext 的实际定义一致
+    branch_id: str
+    start_head_seq: int  # 起点的 head_seq
+    start_sequence: int  # 起点的 sequence
+    start_history_length: int  # 起点的 history 长度
+    side_messages: List[Message]  # 侧分支产生的消息
+    max_turns: int = 5  # 侧分支最大轮次
+    current_turn: int = 0  # 当前轮次
+```
+
+#### 24c. 消息标记
+
+侧分支产生的消息通过 `branch_type` 和 `branch_id` 字段标记:
+- `branch_type`: "compression" | "reflection" | None(主路径)
+- `branch_id`: 同一侧分支的消息共享 branch_id
+- `parent_sequence`: 侧分支消息的 parent 指向主路径或前一条侧分支消息
+
+**关键设计**:`trace.head_sequence` 始终指向主路径的头节点。侧分支执行期间,`head_sequence` 保持在侧分支起点,不更新。侧分支完成后,创建主路径 summary 消息(parent 指向起点),然后更新 `head_sequence` 指向 summary。
+
+这样设计的好处:
+- `get_main_path_messages(trace_id, head_sequence)` 自然返回主路径消息
+- 侧分支消息通过 parent_sequence 链自动脱离主路径,无需额外过滤
+- 续跑时自动加载正确的主路径历史
+
+#### 24d. 停止条件
+
+侧分支使用与主 agent 相同的停止逻辑:
+- LLM 返回无工具调用 → 认为完成
+- 达到 `config.side_branch_max_turns` → 强制停止并处理:
+  - **压缩侧分支**:fallback 到单次 LLM 调用(无工具)
+  - **反思侧分支**:直接退出,不管结果
+
+用户在侧分支中追加的消息自动标记为侧分支消息,继续在侧分支中执行。
+
+#### 24e. 工具 context 传递
+
+侧分支信息通过 `context` 参数传递给工具,保持框架一致性:
+
+```python
+context = {
+    "store": self.trace_store,
+    "trace_id": trace_id,
+    "goal_id": current_goal_id,
+    "runner": self,
+    "goal_tree": goal_tree,
+    "knowledge_config": config.knowledge,
+    # 新增:侧分支信息
+    "side_branch": {
+        "type": side_branch_ctx.type,
+        "branch_id": side_branch_ctx.branch_id,
+        "is_side_branch": True,
+        "current_turn": side_branch_ctx.current_turn,
+        "max_turns": side_branch_ctx.max_turns,
+    } if side_branch_ctx else None,
+}
+```
+
+工具可以通过 `context.get("side_branch")` 感知自己是否在侧分支中执行,但当前不需要特殊处理。
+
+#### 24f. 主循环重构
+
+为避免主循环过于复杂,提取以下函数:
+- `_manage_context_usage()`: Context 用量检查、预警、压缩(整合 Level 1/2)
+- `_check_enter_side_branch()`: 检查是否需要进入侧分支
+- `_check_exit_side_branch()`: 检查是否需要退出侧分支
+- `_exit_side_branch()`: 执行退出逻辑(回到起点)
+- `_single_turn_compress()`: 单次 LLM 压缩(fallback 方案)
+
+主循环通过 `if not side_branch_ctx` 控制哪些逻辑只在主路径执行。
+
+#### 24g. 侧分支状态持久化
+
+侧分支状态存储在 `trace.context["active_side_branch"]`:
+- 进入侧分支时创建,记录 `max_turns`(来自 `RunConfig.side_branch_max_turns`,默认 5)
+- 每轮结束时更新 `current_turn`
+- 退出侧分支时清除
+- 续跑时自动恢复,使用持久化的 `max_turns` 值
+
+这确保了中断后可以继续完成侧分支,不浪费已执行的 LLM 调用。
+
+#### 24h. RunConfig 配置
+
+新增字段:
+- `side_branch_max_turns: int = 5` — 侧分支最大轮次,超过后强制退出
+
+### 变更范围
+
+- `agent/trace/models.py` — Message 增加 `branch_type` 和 `branch_id` 字段
+- `agent/core/runner.py` — 增加 `SideBranchContext`,重构 `_agent_loop`
+- `agent/trace/compaction.py` — `_compress_history` 改为状态机模式
+- `agent/trace/protocols.py` — 查询接口支持过滤侧分支消息
+
+**实现**:`agent/core/runner.py:_agent_loop`, `agent/trace/models.py:Message`, `agent/trace/compaction.py`
+
 ---
 ---

+ 14 - 0
agent/trace/models.py

@@ -177,6 +177,10 @@ class Message:
     tool_call_id: Optional[str] = None   # tool 消息关联对应的 tool_call
     tool_call_id: Optional[str] = None   # tool 消息关联对应的 tool_call
     content: Any = None                  # 消息内容(和 LLM API 格式一致)
     content: Any = None                  # 消息内容(和 LLM API 格式一致)
 
 
+    # 侧分支标记
+    branch_type: Optional[Literal["compression", "reflection"]] = None  # 侧分支类型(None = 主路径)
+    branch_id: Optional[str] = None      # 侧分支 ID(同一侧分支的消息共享)
+
     # 元数据
     # 元数据
     prompt_tokens: Optional[int] = None  # 输入 tokens
     prompt_tokens: Optional[int] = None  # 输入 tokens
     completion_tokens: Optional[int] = None  # 输出 tokens
     completion_tokens: Optional[int] = None  # 输出 tokens
@@ -294,6 +298,12 @@ class Message:
         if "parent_sequence" not in filtered_data:
         if "parent_sequence" not in filtered_data:
             filtered_data["parent_sequence"] = None
             filtered_data["parent_sequence"] = None
 
 
+        # 向后兼容:旧消息没有侧分支字段
+        if "branch_type" not in filtered_data:
+            filtered_data["branch_type"] = None
+        if "branch_id" not in filtered_data:
+            filtered_data["branch_id"] = None
+
         return cls(**filtered_data)
         return cls(**filtered_data)
 
 
     @classmethod
     @classmethod
@@ -306,6 +316,8 @@ class Message:
         content: Any = None,
         content: Any = None,
         tool_call_id: Optional[str] = None,
         tool_call_id: Optional[str] = None,
         parent_sequence: Optional[int] = None,
         parent_sequence: Optional[int] = None,
+        branch_type: Optional[Literal["compression", "reflection"]] = None,
+        branch_id: Optional[str] = None,
         prompt_tokens: Optional[int] = None,
         prompt_tokens: Optional[int] = None,
         completion_tokens: Optional[int] = None,
         completion_tokens: Optional[int] = None,
         reasoning_tokens: Optional[int] = None,
         reasoning_tokens: Optional[int] = None,
@@ -328,6 +340,8 @@ class Message:
             content=content,
             content=content,
             description=description,
             description=description,
             tool_call_id=tool_call_id,
             tool_call_id=tool_call_id,
+            branch_type=branch_type,
+            branch_id=branch_id,
             prompt_tokens=prompt_tokens,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             completion_tokens=completion_tokens,
             reasoning_tokens=reasoning_tokens,
             reasoning_tokens=reasoning_tokens,

+ 6 - 2
agent/trace/store.py

@@ -521,10 +521,14 @@ class FileSystemTraceStore:
         head_sequence: int
         head_sequence: int
     ) -> List[Message]:
     ) -> List[Message]:
         """
         """
-        获取主路径上的消息(从 head_sequence 沿 parent_sequence 链回溯到 root)
+        获取从 head_sequence 沿 parent_sequence 链回溯到 root 的完整路径
+
+        此函数是通用的路径追溯函数,返回从指定 head 到 root 的完整消息链。
+        只要 trace.head_sequence 管理正确(指向主路径),此函数自然返回主路径消息。
+        侧分支消息通过 parent_sequence 链自然被跳过(因为主路径的 parent 不指向侧分支)。
 
 
         Returns:
         Returns:
-            按 sequence 正序排列的主路径 Message 列表
+            按 sequence 正序排列的路径 Message 列表
         """
         """
         # 加载所有消息,建立 sequence -> Message 索引
         # 加载所有消息,建立 sequence -> Message 索引
         all_messages = await self.get_trace_messages(trace_id)
         all_messages = await self.get_trace_messages(trace_id)

+ 10 - 4
knowhub/embeddings.py

@@ -10,12 +10,19 @@ import asyncio
 from typing import List, Union
 from typing import List, Union
 import httpx
 import httpx
 
 
-OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
 OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
 OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
 EMBEDDING_MODEL = "openai/text-embedding-3-small"
 EMBEDDING_MODEL = "openai/text-embedding-3-small"
 EMBEDDING_DIM = 1536
 EMBEDDING_DIM = 1536
 
 
 
 
+def _get_api_key() -> str:
+    """获取 API key(延迟读取环境变量)"""
+    key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPEN_ROUTER_API_KEY")
+    if not key:
+        raise ValueError("OPENROUTER_API_KEY or OPEN_ROUTER_API_KEY not set in environment")
+    return key
+
+
 async def get_embedding(text: str) -> List[float]:
 async def get_embedding(text: str) -> List[float]:
     """
     """
     生成单条文本的向量
     生成单条文本的向量
@@ -64,14 +71,13 @@ async def _call_embedding_api(texts: List[str]) -> List[List[float]]:
     Returns:
     Returns:
         向量列表
         向量列表
     """
     """
-    if not OPENROUTER_API_KEY:
-        raise ValueError("OPENROUTER_API_KEY not set in environment")
+    api_key = _get_api_key()
 
 
     async with httpx.AsyncClient(timeout=30.0) as client:
     async with httpx.AsyncClient(timeout=30.0) as client:
         response = await client.post(
         response = await client.post(
             f"{OPENROUTER_BASE_URL}/embeddings",
             f"{OPENROUTER_BASE_URL}/embeddings",
             headers={
             headers={
-                "Authorization": f"Bearer {OPENROUTER_API_KEY}",
+                "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
                 "Content-Type": "application/json",
             },
             },
             json={
             json={

+ 1 - 1
knowhub/server.py

@@ -1625,4 +1625,4 @@ def frontend():
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import uvicorn
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=9999)
+    uvicorn.run(app, host="0.0.0.0", port=9998)

+ 59 - 0
test_embeddings.py

@@ -0,0 +1,59 @@
"""Standalone smoke test for the embeddings module (no Milvus dependency)."""

import asyncio
import sys
from pathlib import Path

_HERE = Path(__file__).parent
sys.path.insert(0, str(_HERE))

# Load environment variables before importing the module under test,
# so OPENROUTER_API_KEY is visible to knowhub.embeddings.
from dotenv import load_dotenv
load_dotenv(_HERE / ".env")

from knowhub.embeddings import get_embedding, get_embeddings_batch


async def test_embeddings():
    """Exercise single and batch embedding generation against the live API.

    Prints a pass/fail report for each case and returns early on the
    first failure (network or auth errors are reported, not raised).
    """
    banner = "=" * 60
    print(banner)
    print("测试 Embeddings 模块")
    print(banner)

    # Case 1: a single text -> one embedding vector.
    print("\n1. 测试单条 embedding 生成...")
    sample = "如何使用 Python 读取 PDF 文件"
    try:
        vector = await get_embedding(sample)
        print(f"✓ 成功生成 embedding")
        print(f"  文本: {sample}")
        print(f"  向量维度: {len(vector)}")
        print(f"  前 5 个值: {vector[:5]}")
    except Exception as err:
        print(f"✗ 失败: {err}")
        return

    # Case 2: a batch of texts -> one vector per text.
    print("\n2. 测试批量 embedding 生成...")
    batch = [
        "使用 pymupdf 读取 PDF",
        "使用 selenium 进行网页自动化",
        "使用 pandas 处理数据"
    ]
    try:
        vectors = await get_embeddings_batch(batch)
        print(f"✓ 成功生成批量 embeddings")
        print(f"  文本数量: {len(batch)}")
        print(f"  向量数量: {len(vectors)}")
        print(f"  每个向量维度: {len(vectors[0])}")
    except Exception as err:
        print(f"✗ 失败: {err}")
        return

    print("\n" + banner)
    print("Embeddings 模块测试通过!")
    print(banner)


if __name__ == "__main__":
    asyncio.run(test_embeddings())