interactive.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. """
  2. 交互式控制器
  3. 提供暂停/继续、交互式菜单、经验总结等功能。
  4. """
  5. import sys
  6. import asyncio
  7. from typing import Optional, Dict, Any
  8. from pathlib import Path
  9. from agent.core.runner import AgentRunner
  10. from agent.trace import TraceStore
  11. # ===== 非阻塞 stdin 检测 =====
  12. if sys.platform == 'win32':
  13. import msvcrt
  14. def check_stdin() -> Optional[str]:
  15. """
  16. 跨平台非阻塞检查 stdin 输入。
  17. Windows: 使用 msvcrt.kbhit()
  18. macOS/Linux: 使用 select.select()
  19. Returns:
  20. 'pause' | 'quit' | None
  21. """
  22. if sys.platform == 'win32':
  23. # Windows: 检查是否有按键按下
  24. if msvcrt.kbhit():
  25. ch = msvcrt.getwch().lower()
  26. if ch == 'p':
  27. return 'pause'
  28. if ch == 'q':
  29. return 'quit'
  30. return None
  31. else:
  32. # Unix/Mac: 使用 select
  33. import select
  34. ready, _, _ = select.select([sys.stdin], [], [], 0)
  35. if ready:
  36. line = sys.stdin.readline().strip().lower()
  37. if line in ('p', 'pause'):
  38. return 'pause'
  39. if line in ('q', 'quit'):
  40. return 'quit'
  41. return None
  42. def read_multiline() -> str:
  43. """
  44. 读取多行输入,以连续两次回车(空行)结束。
  45. Returns:
  46. 用户输入的多行文本
  47. """
  48. print("\n请输入干预消息(连续输入两次回车结束):")
  49. lines = []
  50. blank_count = 0
  51. while True:
  52. line = input()
  53. if line == "":
  54. blank_count += 1
  55. if blank_count >= 2:
  56. break
  57. lines.append("") # 保留单个空行
  58. else:
  59. blank_count = 0
  60. lines.append(line)
  61. # 去掉尾部多余空行
  62. while lines and lines[-1] == "":
  63. lines.pop()
  64. return "\n".join(lines)
  65. # ===== 交互式控制器 =====
  66. class InteractiveController:
  67. """
  68. 交互式控制器
  69. 管理暂停/继续、交互式菜单、经验总结等交互功能。
  70. """
  71. def __init__(
  72. self,
  73. runner: AgentRunner,
  74. store: TraceStore,
  75. enable_stdin_check: bool = True
  76. ):
  77. """
  78. 初始化交互式控制器
  79. Args:
  80. runner: Agent Runner 实例
  81. store: Trace Store 实例
  82. enable_stdin_check: 是否启用 stdin 检查
  83. """
  84. self.runner = runner
  85. self.store = store
  86. self.enable_stdin_check = enable_stdin_check
  87. def check_stdin(self) -> Optional[str]:
  88. """
  89. 检查 stdin 输入
  90. Returns:
  91. 'pause' | 'quit' | None
  92. """
  93. if not self.enable_stdin_check:
  94. return None
  95. return check_stdin()
  96. async def show_menu(
  97. self,
  98. trace_id: str,
  99. current_sequence: int
  100. ) -> Dict[str, Any]:
  101. """
  102. 显示交互式菜单
  103. Args:
  104. trace_id: Trace ID
  105. current_sequence: 当前消息序号
  106. Returns:
  107. 用户选择的操作
  108. """
  109. print("\n" + "=" * 60)
  110. print(" 执行已暂停")
  111. print("=" * 60)
  112. print("请选择操作:")
  113. print(" 1. 插入干预消息并继续")
  114. print(" 2. 触发经验总结(reflect)")
  115. print(" 3. 查看当前 GoalTree")
  116. print(" 4. 手动压缩上下文(compact)")
  117. print(" 5. 继续执行")
  118. print(" 6. 停止执行")
  119. print("=" * 60)
  120. while True:
  121. choice = input("请输入选项 (1-6): ").strip()
  122. if choice == "1":
  123. # 插入干预消息
  124. text = read_multiline()
  125. if not text:
  126. print("未输入任何内容,取消操作")
  127. continue
  128. print(f"\n将插入干预消息并继续执行...")
  129. # 从 store 读取实际的 last_sequence
  130. live_trace = await self.store.get_trace(trace_id)
  131. actual_sequence = live_trace.last_sequence if live_trace and live_trace.last_sequence else current_sequence
  132. return {
  133. "action": "continue",
  134. "messages": [{"role": "user", "content": text}],
  135. "after_sequence": actual_sequence,
  136. }
  137. elif choice == "2":
  138. # 触发经验总结
  139. print("\n触发经验总结...")
  140. focus = input("请输入反思重点(可选,直接回车跳过): ").strip()
  141. await self.perform_reflection(trace_id, focus=focus)
  142. continue
  143. elif choice == "3":
  144. # 查看 GoalTree
  145. goal_tree = await self.store.get_goal_tree(trace_id)
  146. if goal_tree and goal_tree.goals:
  147. print("\n当前 GoalTree:")
  148. print(goal_tree.to_prompt())
  149. else:
  150. print("\n当前没有 Goal")
  151. continue
  152. elif choice == "4":
  153. # 手动压缩上下文
  154. await self.manual_compact(trace_id)
  155. continue
  156. elif choice == "5":
  157. # 继续执行
  158. print("\n继续执行...")
  159. return {"action": "continue"}
  160. elif choice == "6":
  161. # 停止执行
  162. print("\n停止执行...")
  163. return {"action": "stop"}
  164. else:
  165. print("无效选项,请重新输入")
  166. async def perform_reflection(
  167. self,
  168. trace_id: str,
  169. focus: str = ""
  170. ):
  171. """
  172. 执行经验总结
  173. 通过调用 API 端点触发反思侧分支。
  174. Args:
  175. trace_id: Trace ID
  176. focus: 反思重点(可选)
  177. """
  178. import httpx
  179. print("正在启动反思任务...")
  180. try:
  181. # 调用 reflect API 端点
  182. async with httpx.AsyncClient() as client:
  183. payload = {}
  184. if focus:
  185. payload["focus"] = focus
  186. response = await client.post(
  187. f"http://localhost:8000/api/traces/{trace_id}/reflect",
  188. json=payload,
  189. timeout=10.0
  190. )
  191. response.raise_for_status()
  192. result = response.json()
  193. print(f"✅ 反思任务已启动: {result.get('message', '')}")
  194. print("提示:可通过 WebSocket 监听实时进度")
  195. except httpx.HTTPError as e:
  196. print(f"❌ 反思任务启动失败: {e}")
  197. except Exception as e:
  198. print(f"❌ 发生错误: {e}")
  199. async def manual_compact(self, trace_id: str):
  200. """
  201. 手动压缩上下文
  202. 通过调用 API 端点触发压缩侧分支。
  203. Args:
  204. trace_id: Trace ID
  205. """
  206. import httpx
  207. print("\n正在启动上下文压缩任务...")
  208. try:
  209. # 调用 compact API 端点
  210. async with httpx.AsyncClient() as client:
  211. response = await client.post(
  212. f"http://localhost:8000/api/traces/{trace_id}/compact",
  213. timeout=10.0
  214. )
  215. response.raise_for_status()
  216. result = response.json()
  217. print(f"✅ 压缩任务已启动: {result.get('message', '')}")
  218. print("提示:可通过 WebSocket 监听实时进度")
  219. except httpx.HTTPError as e:
  220. print(f"❌ 压缩任务启动失败: {e}")
  221. except Exception as e:
  222. print(f"❌ 发生错误: {e}")