| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427 |
- """
- 交互式控制器
- 提供暂停/继续、交互式菜单、经验总结等功能。
- """
- import sys
- import asyncio
- from typing import Optional, Dict, Any
- from pathlib import Path
- from agent.core.runner import AgentRunner
- from agent.trace import TraceStore
- # ===== 非阻塞 stdin 检测 =====
- if sys.platform == 'win32':
- import msvcrt
- def check_stdin() -> Optional[str]:
- """
- 跨平台非阻塞检查 stdin 输入。
- Windows: 使用 msvcrt.kbhit()
- macOS/Linux: 使用 select.select()
- Returns:
- 'pause' | 'quit' | None
- """
- if sys.platform == 'win32':
- # Windows: 检查是否有按键按下
- if msvcrt.kbhit():
- ch = msvcrt.getwch().lower()
- if ch == 'p':
- return 'pause'
- if ch == 'q':
- return 'quit'
- return None
- else:
- # Unix/Mac: 使用 select
- import select
- ready, _, _ = select.select([sys.stdin], [], [], 0)
- if ready:
- line = sys.stdin.readline().strip().lower()
- if line in ('p', 'pause'):
- return 'pause'
- if line in ('q', 'quit'):
- return 'quit'
- return None
- def read_multiline() -> str:
- """
- 读取多行输入,以连续两次回车(空行)结束。
- Returns:
- 用户输入的多行文本
- """
- print("\n请输入干预消息(连续输入两次回车结束):")
- lines = []
- blank_count = 0
- while True:
- line = input()
- if line == "":
- blank_count += 1
- if blank_count >= 2:
- break
- lines.append("") # 保留单个空行
- else:
- blank_count = 0
- lines.append(line)
- # 去掉尾部多余空行
- while lines and lines[-1] == "":
- lines.pop()
- return "\n".join(lines)
- # ===== 交互式控制器 =====
- class InteractiveController:
- """
- 交互式控制器
- 管理暂停/继续、交互式菜单、经验总结等交互功能。
- """
- def __init__(
- self,
- runner: AgentRunner,
- store: TraceStore,
- enable_stdin_check: bool = True
- ):
- """
- 初始化交互式控制器
- Args:
- runner: Agent Runner 实例
- store: Trace Store 实例
- enable_stdin_check: 是否启用 stdin 检查
- """
- self.runner = runner
- self.store = store
- self.enable_stdin_check = enable_stdin_check
- def check_stdin(self) -> Optional[str]:
- """
- 检查 stdin 输入
- Returns:
- 'pause' | 'quit' | None
- """
- if not self.enable_stdin_check:
- return None
- return check_stdin()
- async def show_menu(
- self,
- trace_id: str,
- current_sequence: int
- ) -> Dict[str, Any]:
- """
- 显示交互式菜单
- Args:
- trace_id: Trace ID
- current_sequence: 当前消息序号
- Returns:
- 用户选择的操作
- """
- print("\n" + "=" * 60)
- print(" 执行已暂停")
- print("=" * 60)
- print("请选择操作:")
- print(" 1. 插入干预消息并继续")
- print(" 2. 触发经验总结(reflect)")
- print(" 3. 查看当前 GoalTree")
- print(" 4. 手动压缩上下文(compact)")
- print(" 5. 从指定消息续跑")
- print(" 6. 继续执行")
- print(" 7. 停止执行")
- print("=" * 60)
- while True:
- choice = input("请输入选项 (1-7): ").strip()
- if choice == "1":
- # 插入干预消息
- text = read_multiline()
- if not text:
- print("未输入任何内容,取消操作")
- continue
- print(f"\n将插入干预消息并继续执行...")
- # 从 store 读取实际的 last_sequence
- live_trace = await self.store.get_trace(trace_id)
- actual_sequence = live_trace.last_sequence if live_trace and live_trace.last_sequence else current_sequence
- return {
- "action": "continue",
- "messages": [{"role": "user", "content": text}],
- "after_sequence": actual_sequence,
- }
- elif choice == "2":
- # 触发经验总结
- print("\n触发经验总结...")
- focus = input("请输入反思重点(可选,直接回车跳过): ").strip()
- await self.perform_reflection(trace_id, focus=focus)
- continue
- elif choice == "3":
- # 查看 GoalTree
- goal_tree = await self.store.get_goal_tree(trace_id)
- if goal_tree and goal_tree.goals:
- print("\n当前 GoalTree:")
- print(goal_tree.to_prompt())
- else:
- print("\n当前没有 Goal")
- continue
- elif choice == "4":
- # 手动压缩上下文
- await self.manual_compact(trace_id)
- continue
- elif choice == "5":
- # 从指定消息续跑
- await self.resume_from_message(trace_id)
- return {"action": "stop"} # 返回 stop,让外层循环退出
- elif choice == "6":
- # 继续执行
- print("\n继续执行...")
- return {"action": "continue"}
- elif choice == "7":
- # 停止执行
- print("\n停止执行...")
- return {"action": "stop"}
- else:
- print("无效选项,请重新输入")
- async def perform_reflection(
- self,
- trace_id: str,
- focus: str = ""
- ):
- """
- 执行经验总结
- 通过调用 API 端点触发反思侧分支。
- Args:
- trace_id: Trace ID
- focus: 反思重点(可选)
- """
- import httpx
- print("正在启动反思任务...")
- try:
- # 调用 reflect API 端点
- async with httpx.AsyncClient() as client:
- payload = {}
- if focus:
- payload["focus"] = focus
- response = await client.post(
- f"http://localhost:8000/api/traces/{trace_id}/reflect",
- json=payload,
- timeout=10.0
- )
- response.raise_for_status()
- result = response.json()
- print(f"✅ 反思任务已启动: {result.get('message', '')}")
- print("提示:可通过 WebSocket 监听实时进度")
- except httpx.HTTPError as e:
- print(f"❌ 反思任务启动失败: {e}")
- except Exception as e:
- print(f"❌ 发生错误: {e}")
- async def manual_compact(self, trace_id: str):
- """
- 手动压缩上下文
- 通过调用 API 端点触发压缩侧分支。
- Args:
- trace_id: Trace ID
- """
- import httpx
- print("\n正在启动上下文压缩任务...")
- try:
- # 调用 compact API 端点
- async with httpx.AsyncClient() as client:
- response = await client.post(
- f"http://localhost:8000/api/traces/{trace_id}/compact",
- timeout=10.0
- )
- response.raise_for_status()
- result = response.json()
- print(f"✅ 压缩任务已启动: {result.get('message', '')}")
- print("提示:可通过 WebSocket 监听实时进度")
- except httpx.HTTPError as e:
- print(f"❌ 压缩任务启动失败: {e}")
- except Exception as e:
- print(f"❌ 发生错误: {e}")
- async def resume_from_message(self, trace_id: str):
- """
- 从指定消息续跑
- 让用户选择一条消息,然后从该消息之后重新执行。
- Args:
- trace_id: Trace ID
- """
- print("\n正在加载消息列表...")
- # 1. 获取所有消息
- messages = await self.store.get_messages(trace_id)
- if not messages:
- print("❌ 没有找到任何消息")
- return
- # 2. 显示消息列表(只显示 user 和 assistant 消息)
- display_messages = [
- msg for msg in messages
- if msg.role in ("user", "assistant")
- ]
- if not display_messages:
- print("❌ 没有可选择的消息")
- return
- print("\n" + "=" * 60)
- print(" 消息列表")
- print("=" * 60)
- for i, msg in enumerate(display_messages, 1):
- role_label = "👤 User" if msg.role == "user" else "🤖 Assistant"
- content_preview = self._get_content_preview(msg.content)
- print(f"{i}. [{msg.sequence:04d}] {role_label}: {content_preview}")
- print("=" * 60)
- # 3. 让用户选择
- while True:
- choice = input(f"\n请选择消息编号 (1-{len(display_messages)}),或输入 'c' 取消: ").strip()
- if choice.lower() == 'c':
- print("已取消")
- return
- try:
- idx = int(choice) - 1
- if 0 <= idx < len(display_messages):
- selected_msg = display_messages[idx]
- break
- else:
- print(f"无效编号,请输入 1-{len(display_messages)}")
- except ValueError:
- print("无效输入,请输入数字或 'c'")
- # 4. 确认是否重新生成最后一条消息
- regenerate_last = False
- if selected_msg.role == "assistant":
- confirm = input("\n是否重新生成这条 Assistant 消息?(y/n): ").strip().lower()
- regenerate_last = (confirm == 'y')
- # 5. 调用 runner.run() 续跑
- print(f"\n从消息 {selected_msg.sequence:04d} 之后续跑...")
- if regenerate_last:
- print("将重新生成最后一条 Assistant 消息")
- try:
- # 加载 trace 和消息历史
- trace = await self.store.get_trace(trace_id)
- if not trace:
- print("❌ Trace 不存在")
- return
- # 截断消息到指定位置
- truncated_messages = []
- for msg in messages:
- if msg.sequence <= selected_msg.sequence:
- truncated_messages.append({
- "role": msg.role,
- "content": msg.content,
- "id": msg.message_id,
- })
- # 如果需要重新生成,删除最后一条 assistant 消息
- if regenerate_last and truncated_messages and truncated_messages[-1]["role"] == "assistant":
- truncated_messages.pop()
- # 调用 runner.run() 续跑
- print("\n开始执行...")
- async for event in self.runner.run(
- messages=truncated_messages,
- trace_id=trace_id,
- model=trace.model,
- temperature=trace.llm_params.get("temperature", 0.3),
- max_iterations=200,
- tools=None, # 使用原有配置
- ):
- # 简单输出事件
- if event.get("type") == "message":
- msg = event.get("message")
- if msg and msg.get("role") == "assistant":
- content = msg.get("content", {})
- if isinstance(content, dict):
- text = content.get("text", "")
- else:
- text = str(content)
- if text:
- print(f"\n🤖 Assistant: {text[:200]}...")
- print("\n✅ 执行完成")
- except Exception as e:
- print(f"❌ 执行失败: {e}")
- import traceback
- traceback.print_exc()
- def _get_content_preview(self, content: Any, max_length: int = 60) -> str:
- """
- 获取消息内容预览
- Args:
- content: 消息内容
- max_length: 最大长度
- Returns:
- 内容预览字符串
- """
- if isinstance(content, dict):
- text = content.get("text", "")
- tool_calls = content.get("tool_calls", [])
- if text:
- preview = text.strip()
- elif tool_calls:
- preview = f"[调用工具: {', '.join(tc.get('function', {}).get('name', '?') for tc in tool_calls)}]"
- else:
- preview = "[空消息]"
- elif isinstance(content, str):
- preview = content.strip()
- else:
- preview = str(content)
- if len(preview) > max_length:
- preview = preview[:max_length] + "..."
- return preview
|