interactive.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. """
  2. 交互式控制器
  3. 提供暂停/继续、交互式菜单、经验总结等功能。
  4. """
  5. import sys
  6. import asyncio
  7. from typing import Optional, Dict, Any
  8. from pathlib import Path
  9. from agent.core.runner import AgentRunner
  10. from agent.trace import TraceStore
  11. # ===== 非阻塞 stdin 检测 =====
  12. if sys.platform == 'win32':
  13. import msvcrt
  14. def check_stdin() -> Optional[str]:
  15. """
  16. 跨平台非阻塞检查 stdin 输入。
  17. 支持终端输入和控制文件(用于后台进程控制)。
  18. 优先级:
  19. 1. 检查控制文件 .agent_control(用于后台进程)
  20. 2. 检查终端/管道输入
  21. Returns:
  22. 'pause' | 'quit' | None
  23. """
  24. # 1. 优先检查控制文件(用于后台进程控制)
  25. control_file = Path.cwd() / ".agent_control"
  26. if control_file.exists():
  27. try:
  28. cmd = control_file.read_text(encoding='utf-8').strip().lower()
  29. control_file.unlink() # 读取后立即删除
  30. if cmd in ('p', 'pause'):
  31. return 'pause'
  32. if cmd in ('q', 'quit'):
  33. return 'quit'
  34. except Exception:
  35. pass
  36. # 2. 检查终端/管道输入
  37. if sys.platform == 'win32':
  38. # Windows: 先检查是否是终端
  39. if sys.stdin.isatty():
  40. # 终端模式:使用 msvcrt
  41. if msvcrt.kbhit():
  42. ch = msvcrt.getwch().lower()
  43. if ch == 'p':
  44. return 'pause'
  45. if ch == 'q':
  46. return 'quit'
  47. else:
  48. # 管道模式:尝试非阻塞读取
  49. try:
  50. import select
  51. ready, _, _ = select.select([sys.stdin], [], [], 0)
  52. if ready:
  53. line = sys.stdin.readline().strip().lower()
  54. if line in ('p', 'pause'):
  55. return 'pause'
  56. if line in ('q', 'quit'):
  57. return 'quit'
  58. except Exception:
  59. pass
  60. return None
  61. else:
  62. # Unix/Mac: 使用 select(支持终端和管道)
  63. import select
  64. ready, _, _ = select.select([sys.stdin], [], [], 0)
  65. if ready:
  66. line = sys.stdin.readline().strip().lower()
  67. if line in ('p', 'pause'):
  68. return 'pause'
  69. if line in ('q', 'quit'):
  70. return 'quit'
  71. return None
  72. def read_multiline() -> str:
  73. """
  74. 读取多行输入,以连续两次回车(空行)结束。
  75. Returns:
  76. 用户输入的多行文本
  77. """
  78. print("\n请输入干预消息(连续输入两次回车结束):")
  79. lines = []
  80. blank_count = 0
  81. while True:
  82. line = input()
  83. if line == "":
  84. blank_count += 1
  85. if blank_count >= 2:
  86. break
  87. lines.append("") # 保留单个空行
  88. else:
  89. blank_count = 0
  90. lines.append(line)
  91. # 去掉尾部多余空行
  92. while lines and lines[-1] == "":
  93. lines.pop()
  94. return "\n".join(lines)
  95. # ===== 交互式控制器 =====
  96. class InteractiveController:
  97. """
  98. 交互式控制器
  99. 管理暂停/继续、交互式菜单、经验总结等交互功能。
  100. """
  101. def __init__(
  102. self,
  103. runner: AgentRunner,
  104. store: TraceStore,
  105. enable_stdin_check: bool = True
  106. ):
  107. """
  108. 初始化交互式控制器
  109. Args:
  110. runner: Agent Runner 实例
  111. store: Trace Store 实例
  112. enable_stdin_check: 是否启用 stdin 检查
  113. """
  114. self.runner = runner
  115. self.store = store
  116. self.enable_stdin_check = enable_stdin_check
  117. def check_stdin(self) -> Optional[str]:
  118. """
  119. 检查 stdin 输入
  120. Returns:
  121. 'pause' | 'quit' | None
  122. """
  123. if not self.enable_stdin_check:
  124. return None
  125. return check_stdin()
  126. async def show_menu(
  127. self,
  128. trace_id: str,
  129. current_sequence: int
  130. ) -> Dict[str, Any]:
  131. """
  132. 显示交互式菜单
  133. Args:
  134. trace_id: Trace ID
  135. current_sequence: 当前消息序号
  136. Returns:
  137. 用户选择的操作
  138. """
  139. print("\n" + "=" * 60)
  140. print(" 执行已暂停")
  141. print("=" * 60)
  142. print("请选择操作:")
  143. print(" 1. 插入干预消息并继续")
  144. print(" 2. 触发经验总结(reflect)")
  145. print(" 3. 查看当前 GoalTree")
  146. print(" 4. 手动压缩上下文(compact)")
  147. print(" 5. 从指定消息续跑")
  148. print(" 6. 继续执行")
  149. print(" 7. 停止执行")
  150. print("=" * 60)
  151. while True:
  152. choice = input("请输入选项 (1-7): ").strip()
  153. if choice == "1":
  154. # 插入干预消息
  155. text = read_multiline()
  156. if not text:
  157. print("未输入任何内容,取消操作")
  158. continue
  159. print(f"\n将插入干预消息并继续执行...")
  160. # 从 store 读取实际的 last_sequence
  161. live_trace = await self.store.get_trace(trace_id)
  162. actual_sequence = live_trace.last_sequence if live_trace and live_trace.last_sequence else current_sequence
  163. return {
  164. "action": "continue",
  165. "messages": [{"role": "user", "content": text}],
  166. "after_sequence": actual_sequence,
  167. }
  168. elif choice == "2":
  169. # 触发经验总结
  170. print("\n触发经验总结...")
  171. focus = input("请输入反思重点(可选,直接回车跳过): ").strip()
  172. await self.perform_reflection(trace_id, focus=focus)
  173. continue
  174. elif choice == "3":
  175. # 查看 GoalTree
  176. goal_tree = await self.store.get_goal_tree(trace_id)
  177. if goal_tree and goal_tree.goals:
  178. print("\n当前 GoalTree:")
  179. print(goal_tree.to_prompt())
  180. else:
  181. print("\n当前没有 Goal")
  182. continue
  183. elif choice == "4":
  184. # 手动压缩上下文
  185. await self.manual_compact(trace_id)
  186. continue
  187. elif choice == "5":
  188. # 从指定消息续跑
  189. await self.resume_from_message(trace_id)
  190. return {"action": "stop"} # 返回 stop,让外层循环退出
  191. elif choice == "6":
  192. # 继续执行
  193. print("\n继续执行...")
  194. return {"action": "continue"}
  195. elif choice == "7":
  196. # 停止执行
  197. print("\n停止执行...")
  198. return {"action": "stop"}
  199. else:
  200. print("无效选项,请重新输入")
  201. async def perform_reflection(
  202. self,
  203. trace_id: str,
  204. focus: str = ""
  205. ):
  206. """
  207. 执行经验总结
  208. 通过调用 API 端点触发反思侧分支。
  209. Args:
  210. trace_id: Trace ID
  211. focus: 反思重点(可选)
  212. """
  213. import httpx
  214. print("正在启动反思任务...")
  215. try:
  216. # 调用 reflect API 端点
  217. async with httpx.AsyncClient() as client:
  218. payload = {}
  219. if focus:
  220. payload["focus"] = focus
  221. response = await client.post(
  222. f"http://localhost:8000/api/traces/{trace_id}/reflect",
  223. json=payload,
  224. timeout=10.0
  225. )
  226. response.raise_for_status()
  227. result = response.json()
  228. print(f"✅ 反思任务已启动: {result.get('message', '')}")
  229. print("提示:可通过 WebSocket 监听实时进度")
  230. except httpx.HTTPError as e:
  231. print(f"❌ 反思任务启动失败: {e}")
  232. except Exception as e:
  233. print(f"❌ 发生错误: {e}")
  234. async def manual_compact(self, trace_id: str):
  235. """
  236. 手动压缩上下文
  237. 通过调用 API 端点触发压缩侧分支。
  238. Args:
  239. trace_id: Trace ID
  240. """
  241. import httpx
  242. print("\n正在启动上下文压缩任务...")
  243. try:
  244. # 调用 compact API 端点
  245. async with httpx.AsyncClient() as client:
  246. response = await client.post(
  247. f"http://localhost:8000/api/traces/{trace_id}/compact",
  248. timeout=10.0
  249. )
  250. response.raise_for_status()
  251. result = response.json()
  252. print(f"✅ 压缩任务已启动: {result.get('message', '')}")
  253. print("提示:可通过 WebSocket 监听实时进度")
  254. except httpx.HTTPError as e:
  255. print(f"❌ 压缩任务启动失败: {e}")
  256. except Exception as e:
  257. print(f"❌ 发生错误: {e}")
  258. async def resume_from_message(self, trace_id: str):
  259. """
  260. 从指定消息续跑
  261. 让用户选择一条消息,然后从该消息之后重新执行。
  262. Args:
  263. trace_id: Trace ID
  264. """
  265. print("\n正在加载消息列表...")
  266. # 1. 获取所有消息
  267. messages = await self.store.get_trace_messages(trace_id)
  268. if not messages:
  269. print("❌ 没有找到任何消息")
  270. return
  271. # 2. 显示消息列表(只显示 user 和 assistant 消息)
  272. display_messages = [
  273. msg for msg in messages
  274. if msg.role in ("user", "assistant")
  275. ]
  276. if not display_messages:
  277. print("❌ 没有可选择的消息")
  278. return
  279. print("\n" + "=" * 60)
  280. print(" 消息列表")
  281. print("=" * 60)
  282. for i, msg in enumerate(display_messages, 1):
  283. role_label = "👤 User" if msg.role == "user" else "🤖 Assistant"
  284. content_preview = self._get_content_preview(msg.content)
  285. print(f"{i}. [{msg.sequence:04d}] {role_label}: {content_preview}")
  286. print("=" * 60)
  287. # 3. 让用户选择
  288. while True:
  289. choice = input(f"\n请选择消息编号 (1-{len(display_messages)}),或输入 'c' 取消: ").strip()
  290. if choice.lower() == 'c':
  291. print("已取消")
  292. return
  293. try:
  294. idx = int(choice) - 1
  295. if 0 <= idx < len(display_messages):
  296. selected_msg = display_messages[idx]
  297. break
  298. else:
  299. print(f"无效编号,请输入 1-{len(display_messages)}")
  300. except ValueError:
  301. print("无效输入,请输入数字或 'c'")
  302. # 4. 确认是否重新生成最后一条消息
  303. regenerate_last = False
  304. if selected_msg.role == "assistant":
  305. confirm = input("\n是否重新生成这条 Assistant 消息?(y/n): ").strip().lower()
  306. regenerate_last = (confirm == 'y')
  307. # 5. 调用 runner.run() 续跑
  308. print(f"\n从消息 {selected_msg.sequence:04d} 之后续跑...")
  309. if regenerate_last:
  310. print("将重新生成最后一条 Assistant 消息")
  311. try:
  312. # 加载 trace 和消息历史
  313. trace = await self.store.get_trace(trace_id)
  314. if not trace:
  315. print("❌ Trace 不存在")
  316. return
  317. # 截断消息到指定位置
  318. truncated_messages = []
  319. for msg in messages:
  320. if msg.sequence <= selected_msg.sequence:
  321. truncated_messages.append({
  322. "role": msg.role,
  323. "content": msg.content,
  324. "id": msg.message_id,
  325. })
  326. # 如果需要重新生成,删除最后一条 assistant 消息
  327. if regenerate_last and truncated_messages and truncated_messages[-1]["role"] == "assistant":
  328. truncated_messages.pop()
  329. # 调用 runner.run() 续跑
  330. print("\n开始执行...")
  331. async for event in self.runner.run(
  332. messages=truncated_messages,
  333. trace_id=trace_id,
  334. model=trace.model,
  335. temperature=trace.llm_params.get("temperature", 0.3),
  336. max_iterations=200,
  337. tools=None, # 使用原有配置
  338. ):
  339. # 简单输出事件
  340. if event.get("type") == "message":
  341. msg = event.get("message")
  342. if msg and msg.get("role") == "assistant":
  343. content = msg.get("content", {})
  344. if isinstance(content, dict):
  345. text = content.get("text", "")
  346. else:
  347. text = str(content)
  348. if text:
  349. print(f"\n🤖 Assistant: {text[:200]}...")
  350. print("\n✅ 执行完成")
  351. except Exception as e:
  352. print(f"❌ 执行失败: {e}")
  353. import traceback
  354. traceback.print_exc()
  355. def _get_content_preview(self, content: Any, max_length: int = 60) -> str:
  356. """
  357. 获取消息内容预览
  358. Args:
  359. content: 消息内容
  360. max_length: 最大长度
  361. Returns:
  362. 内容预览字符串
  363. """
  364. if isinstance(content, dict):
  365. text = content.get("text", "")
  366. tool_calls = content.get("tool_calls", [])
  367. if text:
  368. preview = text.strip()
  369. elif tool_calls:
  370. preview = f"[调用工具: {', '.join(tc.get('function', {}).get('name', '?') for tc in tool_calls)}]"
  371. else:
  372. preview = "[空消息]"
  373. elif isinstance(content, str):
  374. preview = content.strip()
  375. else:
  376. preview = str(content)
  377. if len(preview) > max_length:
  378. preview = preview[:max_length] + "..."
  379. return preview