interactive.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. """
  2. 交互式控制器
  3. 提供暂停/继续、交互式菜单、经验总结等功能。
  4. """
  5. import sys
  6. import asyncio
  7. from typing import Optional, Dict, Any
  8. from pathlib import Path
  9. from agent.core.runner import AgentRunner
  10. from agent.trace import TraceStore
  11. # ===== 非阻塞 stdin 检测 =====
  12. if sys.platform == 'win32':
  13. import msvcrt
  14. def check_stdin() -> Optional[str]:
  15. """
  16. 跨平台非阻塞检查 stdin 输入。
  17. Windows: 使用 msvcrt.kbhit()
  18. macOS/Linux: 使用 select.select()
  19. Returns:
  20. 'pause' | 'quit' | None
  21. """
  22. if sys.platform == 'win32':
  23. # Windows: 检查是否有按键按下
  24. if msvcrt.kbhit():
  25. ch = msvcrt.getwch().lower()
  26. if ch == 'p':
  27. return 'pause'
  28. if ch == 'q':
  29. return 'quit'
  30. return None
  31. else:
  32. # Unix/Mac: 使用 select
  33. import select
  34. ready, _, _ = select.select([sys.stdin], [], [], 0)
  35. if ready:
  36. line = sys.stdin.readline().strip().lower()
  37. if line in ('p', 'pause'):
  38. return 'pause'
  39. if line in ('q', 'quit'):
  40. return 'quit'
  41. return None
  42. def read_multiline() -> str:
  43. """
  44. 读取多行输入,以连续两次回车(空行)结束。
  45. Returns:
  46. 用户输入的多行文本
  47. """
  48. print("\n请输入干预消息(连续输入两次回车结束):")
  49. lines = []
  50. blank_count = 0
  51. while True:
  52. line = input()
  53. if line == "":
  54. blank_count += 1
  55. if blank_count >= 2:
  56. break
  57. lines.append("") # 保留单个空行
  58. else:
  59. blank_count = 0
  60. lines.append(line)
  61. # 去掉尾部多余空行
  62. while lines and lines[-1] == "":
  63. lines.pop()
  64. return "\n".join(lines)
  65. # ===== 交互式控制器 =====
  66. class InteractiveController:
  67. """
  68. 交互式控制器
  69. 管理暂停/继续、交互式菜单、经验总结等交互功能。
  70. """
  71. def __init__(
  72. self,
  73. runner: AgentRunner,
  74. store: TraceStore,
  75. enable_stdin_check: bool = True
  76. ):
  77. """
  78. 初始化交互式控制器
  79. Args:
  80. runner: Agent Runner 实例
  81. store: Trace Store 实例
  82. enable_stdin_check: 是否启用 stdin 检查
  83. """
  84. self.runner = runner
  85. self.store = store
  86. self.enable_stdin_check = enable_stdin_check
  87. def check_stdin(self) -> Optional[str]:
  88. """
  89. 检查 stdin 输入
  90. Returns:
  91. 'pause' | 'quit' | None
  92. """
  93. if not self.enable_stdin_check:
  94. return None
  95. return check_stdin()
  96. async def show_menu(
  97. self,
  98. trace_id: str,
  99. current_sequence: int
  100. ) -> Dict[str, Any]:
  101. """
  102. 显示交互式菜单
  103. Args:
  104. trace_id: Trace ID
  105. current_sequence: 当前消息序号
  106. Returns:
  107. 用户选择的操作
  108. """
  109. print("\n" + "=" * 60)
  110. print(" 执行已暂停")
  111. print("=" * 60)
  112. print("请选择操作:")
  113. print(" 1. 插入干预消息并继续")
  114. print(" 2. 触发经验总结(reflect)")
  115. print(" 3. 查看当前 GoalTree")
  116. print(" 4. 手动压缩上下文(compact)")
  117. print(" 5. 从指定消息续跑")
  118. print(" 6. 继续执行")
  119. print(" 7. 停止执行")
  120. print("=" * 60)
  121. while True:
  122. choice = input("请输入选项 (1-7): ").strip()
  123. if choice == "1":
  124. # 插入干预消息
  125. text = read_multiline()
  126. if not text:
  127. print("未输入任何内容,取消操作")
  128. continue
  129. print(f"\n将插入干预消息并继续执行...")
  130. # 从 store 读取实际的 last_sequence
  131. live_trace = await self.store.get_trace(trace_id)
  132. actual_sequence = live_trace.last_sequence if live_trace and live_trace.last_sequence else current_sequence
  133. return {
  134. "action": "continue",
  135. "messages": [{"role": "user", "content": text}],
  136. "after_sequence": actual_sequence,
  137. }
  138. elif choice == "2":
  139. # 触发经验总结
  140. print("\n触发经验总结...")
  141. focus = input("请输入反思重点(可选,直接回车跳过): ").strip()
  142. await self.perform_reflection(trace_id, focus=focus)
  143. continue
  144. elif choice == "3":
  145. # 查看 GoalTree
  146. goal_tree = await self.store.get_goal_tree(trace_id)
  147. if goal_tree and goal_tree.goals:
  148. print("\n当前 GoalTree:")
  149. print(goal_tree.to_prompt())
  150. else:
  151. print("\n当前没有 Goal")
  152. continue
  153. elif choice == "4":
  154. # 手动压缩上下文
  155. await self.manual_compact(trace_id)
  156. continue
  157. elif choice == "5":
  158. # 从指定消息续跑
  159. await self.resume_from_message(trace_id)
  160. return {"action": "stop"} # 返回 stop,让外层循环退出
  161. elif choice == "6":
  162. # 继续执行
  163. print("\n继续执行...")
  164. return {"action": "continue"}
  165. elif choice == "7":
  166. # 停止执行
  167. print("\n停止执行...")
  168. return {"action": "stop"}
  169. else:
  170. print("无效选项,请重新输入")
  171. async def perform_reflection(
  172. self,
  173. trace_id: str,
  174. focus: str = ""
  175. ):
  176. """
  177. 执行经验总结
  178. 通过调用 API 端点触发反思侧分支。
  179. Args:
  180. trace_id: Trace ID
  181. focus: 反思重点(可选)
  182. """
  183. import httpx
  184. print("正在启动反思任务...")
  185. try:
  186. # 调用 reflect API 端点
  187. async with httpx.AsyncClient() as client:
  188. payload = {}
  189. if focus:
  190. payload["focus"] = focus
  191. response = await client.post(
  192. f"http://localhost:8000/api/traces/{trace_id}/reflect",
  193. json=payload,
  194. timeout=10.0
  195. )
  196. response.raise_for_status()
  197. result = response.json()
  198. print(f"✅ 反思任务已启动: {result.get('message', '')}")
  199. print("提示:可通过 WebSocket 监听实时进度")
  200. except httpx.HTTPError as e:
  201. print(f"❌ 反思任务启动失败: {e}")
  202. except Exception as e:
  203. print(f"❌ 发生错误: {e}")
  204. async def manual_compact(self, trace_id: str):
  205. """
  206. 手动压缩上下文
  207. 通过调用 API 端点触发压缩侧分支。
  208. Args:
  209. trace_id: Trace ID
  210. """
  211. import httpx
  212. print("\n正在启动上下文压缩任务...")
  213. try:
  214. # 调用 compact API 端点
  215. async with httpx.AsyncClient() as client:
  216. response = await client.post(
  217. f"http://localhost:8000/api/traces/{trace_id}/compact",
  218. timeout=10.0
  219. )
  220. response.raise_for_status()
  221. result = response.json()
  222. print(f"✅ 压缩任务已启动: {result.get('message', '')}")
  223. print("提示:可通过 WebSocket 监听实时进度")
  224. except httpx.HTTPError as e:
  225. print(f"❌ 压缩任务启动失败: {e}")
  226. except Exception as e:
  227. print(f"❌ 发生错误: {e}")
  228. async def resume_from_message(self, trace_id: str):
  229. """
  230. 从指定消息续跑
  231. 让用户选择一条消息,然后从该消息之后重新执行。
  232. Args:
  233. trace_id: Trace ID
  234. """
  235. print("\n正在加载消息列表...")
  236. # 1. 获取所有消息
  237. messages = await self.store.get_messages(trace_id)
  238. if not messages:
  239. print("❌ 没有找到任何消息")
  240. return
  241. # 2. 显示消息列表(只显示 user 和 assistant 消息)
  242. display_messages = [
  243. msg for msg in messages
  244. if msg.role in ("user", "assistant")
  245. ]
  246. if not display_messages:
  247. print("❌ 没有可选择的消息")
  248. return
  249. print("\n" + "=" * 60)
  250. print(" 消息列表")
  251. print("=" * 60)
  252. for i, msg in enumerate(display_messages, 1):
  253. role_label = "👤 User" if msg.role == "user" else "🤖 Assistant"
  254. content_preview = self._get_content_preview(msg.content)
  255. print(f"{i}. [{msg.sequence:04d}] {role_label}: {content_preview}")
  256. print("=" * 60)
  257. # 3. 让用户选择
  258. while True:
  259. choice = input(f"\n请选择消息编号 (1-{len(display_messages)}),或输入 'c' 取消: ").strip()
  260. if choice.lower() == 'c':
  261. print("已取消")
  262. return
  263. try:
  264. idx = int(choice) - 1
  265. if 0 <= idx < len(display_messages):
  266. selected_msg = display_messages[idx]
  267. break
  268. else:
  269. print(f"无效编号,请输入 1-{len(display_messages)}")
  270. except ValueError:
  271. print("无效输入,请输入数字或 'c'")
  272. # 4. 确认是否重新生成最后一条消息
  273. regenerate_last = False
  274. if selected_msg.role == "assistant":
  275. confirm = input("\n是否重新生成这条 Assistant 消息?(y/n): ").strip().lower()
  276. regenerate_last = (confirm == 'y')
  277. # 5. 调用 runner.run() 续跑
  278. print(f"\n从消息 {selected_msg.sequence:04d} 之后续跑...")
  279. if regenerate_last:
  280. print("将重新生成最后一条 Assistant 消息")
  281. try:
  282. # 加载 trace 和消息历史
  283. trace = await self.store.get_trace(trace_id)
  284. if not trace:
  285. print("❌ Trace 不存在")
  286. return
  287. # 截断消息到指定位置
  288. truncated_messages = []
  289. for msg in messages:
  290. if msg.sequence <= selected_msg.sequence:
  291. truncated_messages.append({
  292. "role": msg.role,
  293. "content": msg.content,
  294. "id": msg.message_id,
  295. })
  296. # 如果需要重新生成,删除最后一条 assistant 消息
  297. if regenerate_last and truncated_messages and truncated_messages[-1]["role"] == "assistant":
  298. truncated_messages.pop()
  299. # 调用 runner.run() 续跑
  300. print("\n开始执行...")
  301. async for event in self.runner.run(
  302. messages=truncated_messages,
  303. trace_id=trace_id,
  304. model=trace.model,
  305. temperature=trace.llm_params.get("temperature", 0.3),
  306. max_iterations=200,
  307. tools=None, # 使用原有配置
  308. ):
  309. # 简单输出事件
  310. if event.get("type") == "message":
  311. msg = event.get("message")
  312. if msg and msg.get("role") == "assistant":
  313. content = msg.get("content", {})
  314. if isinstance(content, dict):
  315. text = content.get("text", "")
  316. else:
  317. text = str(content)
  318. if text:
  319. print(f"\n🤖 Assistant: {text[:200]}...")
  320. print("\n✅ 执行完成")
  321. except Exception as e:
  322. print(f"❌ 执行失败: {e}")
  323. import traceback
  324. traceback.print_exc()
  325. def _get_content_preview(self, content: Any, max_length: int = 60) -> str:
  326. """
  327. 获取消息内容预览
  328. Args:
  329. content: 消息内容
  330. max_length: 最大长度
  331. Returns:
  332. 内容预览字符串
  333. """
  334. if isinstance(content, dict):
  335. text = content.get("text", "")
  336. tool_calls = content.get("tool_calls", [])
  337. if text:
  338. preview = text.strip()
  339. elif tool_calls:
  340. preview = f"[调用工具: {', '.join(tc.get('function', {}).get('name', '?') for tc in tool_calls)}]"
  341. else:
  342. preview = "[空消息]"
  343. elif isinstance(content, str):
  344. preview = content.strip()
  345. else:
  346. preview = str(content)
  347. if len(preview) > max_length:
  348. preview = preview[:max_length] + "..."
  349. return preview