# overall_derivation_agent_run.py
"""
Topic-point overall derivation Agent (enhanced version).

Modeled on examples/how/run.py; provides:
1. Command-line interaction: type 'p' to pause, 'q' to quit.
2. After pausing: insert an intervention message, trigger an experience
   summary (reflect), view the GoalTree, or manually compact the context.
3. `--trace <ID>` to resume an existing Trace and continue execution.
4. Loads derivation_main.md through SimplePrompt and supports the evaluation
   sub-agent (agent_type=evaluate_derivation).
"""
import argparse
import os
import sys
import select
import asyncio
from datetime import datetime
from pathlib import Path

# Same as examples/how/run.py: stop httpx/urllib from auto-detecting the system
# HTTP proxy. The call is currently disabled (left commented out).
# os.environ.setdefault("no_proxy", "*")

# Project root: three directory levels above this file. It is prepended to
# sys.path before the project imports below, so the script can be launched
# from either the project root or the script directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv

from agent.llm.prompts import SimplePrompt
from agent.core.runner import AgentRunner, RunConfig
from agent.core.presets import AgentPreset, register_preset
from agent.trace import (
    FileSystemTraceStore,
    Trace,
    Message,
)
from agent.llm import create_openrouter_llm_call
from agent.trace.compaction import build_reflect_prompt

# Load variables from a .env file before the project configuration is imported.
load_dotenv()

# Import project configuration.
from config import RUN_CONFIG, SKILLS_DIR, TRACE_STORE_PATH, DEBUG, LOG_LEVEL, LOG_FILE, BROWSER_TYPE, HEADLESS

# ===== Non-blocking stdin detection =====
# msvcrt only exists on Windows; check_stdin() uses it on that platform only.
if sys.platform == 'win32':
    import msvcrt
  38. def check_stdin() -> str | None:
  39. """
  40. 跨平台非阻塞检查 stdin 输入。
  41. Windows: msvcrt.kbhit();macOS/Linux: select.select()
  42. """
  43. if sys.platform == 'win32':
  44. if msvcrt.kbhit():
  45. ch = msvcrt.getwch().lower()
  46. if ch == 'p':
  47. return 'pause'
  48. if ch == 'q':
  49. return 'quit'
  50. return None
  51. else:
  52. ready, _, _ = select.select([sys.stdin], [], [], 0)
  53. if ready:
  54. line = sys.stdin.readline().strip().lower()
  55. if line in ('p', 'pause'):
  56. return 'pause'
  57. if line in ('q', 'quit'):
  58. return 'quit'
  59. return None
# ===== Interactive menu =====
  61. def _read_multiline() -> str:
  62. """读取多行输入,以连续两次回车(空行)结束。"""
  63. print("\n请输入干预消息(连续输入两次回车结束):")
  64. lines: list[str] = []
  65. blank_count = 0
  66. while True:
  67. line = input()
  68. if line == "":
  69. blank_count += 1
  70. if blank_count >= 2:
  71. break
  72. lines.append("")
  73. else:
  74. blank_count = 0
  75. lines.append(line)
  76. while lines and lines[-1] == "":
  77. lines.pop()
  78. return "\n".join(lines)
  79. async def show_interactive_menu(
  80. runner: AgentRunner,
  81. trace_id: str,
  82. current_sequence: int,
  83. store: FileSystemTraceStore,
  84. ):
  85. """显示交互式菜单,让用户选择操作。"""
  86. print("\n" + "=" * 60)
  87. print(" 执行已暂停")
  88. print("=" * 60)
  89. print("请选择操作:")
  90. print(" 1. 插入干预消息并继续")
  91. print(" 2. 触发经验总结(reflect)")
  92. print(" 3. 查看当前 GoalTree")
  93. print(" 4. 手动压缩上下文(compact)")
  94. print(" 5. 继续执行")
  95. print(" 6. 停止执行")
  96. print("=" * 60)
  97. while True:
  98. choice = input("请输入选项 (1-6): ").strip()
  99. if choice == "1":
  100. text = _read_multiline()
  101. if not text:
  102. print("未输入任何内容,取消操作")
  103. continue
  104. print("\n将插入干预消息并继续执行...")
  105. live_trace = await store.get_trace(trace_id)
  106. actual_sequence = live_trace.last_sequence if live_trace and live_trace.last_sequence else current_sequence
  107. return {
  108. "action": "continue",
  109. "messages": [{"role": "user", "content": text}],
  110. "after_sequence": actual_sequence,
  111. }
  112. elif choice == "2":
  113. print("\n触发经验总结...")
  114. focus = input("请输入反思重点(可选,直接回车跳过): ").strip()
  115. trace = await store.get_trace(trace_id)
  116. saved_head = trace.head_sequence
  117. prompt = build_reflect_prompt()
  118. if focus:
  119. prompt += f"\n\n请特别关注:{focus}"
  120. print("正在生成反思...")
  121. reflect_cfg = RunConfig(trace_id=trace_id, max_iterations=1, tools=[])
  122. reflection_text = ""
  123. try:
  124. result = await runner.run_result(
  125. messages=[{"role": "user", "content": prompt}],
  126. config=reflect_cfg,
  127. )
  128. reflection_text = result.get("summary", "")
  129. finally:
  130. await store.update_trace(trace_id, head_sequence=saved_head)
  131. if reflection_text:
  132. from datetime import datetime
  133. experiences_path = runner.experiences_path or "./.cache/experiences_overall_derivation.md"
  134. os.makedirs(os.path.dirname(experiences_path), exist_ok=True)
  135. header = f"\n\n---\n\n## {trace_id} ({datetime.now().strftime('%Y-%m-%d %H:%M')})\n\n"
  136. with open(experiences_path, "a", encoding="utf-8") as f:
  137. f.write(header + reflection_text + "\n")
  138. print(f"\n反思已保存到: {experiences_path}")
  139. print("\n--- 反思内容 ---")
  140. print(reflection_text)
  141. print("--- 结束 ---\n")
  142. else:
  143. print("未生成反思内容")
  144. continue
  145. elif choice == "3":
  146. goal_tree = await store.get_goal_tree(trace_id)
  147. if goal_tree and goal_tree.goals:
  148. print("\n当前 GoalTree:")
  149. print(goal_tree.to_prompt())
  150. else:
  151. print("\n当前没有 Goal")
  152. continue
  153. elif choice == "4":
  154. print("\n正在执行上下文压缩(compact)...")
  155. try:
  156. goal_tree = await store.get_goal_tree(trace_id)
  157. trace = await store.get_trace(trace_id)
  158. if not trace:
  159. print("未找到 Trace,无法压缩")
  160. continue
  161. main_path = await store.get_main_path_messages(trace_id, trace.head_sequence)
  162. history = [msg.to_llm_dict() for msg in main_path]
  163. head_seq = main_path[-1].sequence if main_path else 0
  164. next_seq = head_seq + 1
  165. compact_config = RunConfig(trace_id=trace_id)
  166. new_history, new_head, new_seq = await runner._compress_history(
  167. trace_id=trace_id,
  168. history=history,
  169. goal_tree=goal_tree,
  170. config=compact_config,
  171. sequence=next_seq,
  172. head_seq=head_seq,
  173. )
  174. print(f"\n✅ 压缩完成: {len(history)} 条消息 → {len(new_history)} 条")
  175. except Exception as e:
  176. print(f"\n❌ 压缩失败: {e}")
  177. continue
  178. elif choice == "5":
  179. print("\n继续执行...")
  180. return {"action": "continue"}
  181. elif choice == "6":
  182. print("\n停止执行...")
  183. return {"action": "stop"}
  184. else:
  185. print("无效选项,请重新输入")
  186. def _replace_prompt_placeholders(
  187. messages: list,
  188. account_name: str,
  189. post_id: str,
  190. log_id: str,
  191. post_point_count: int,
  192. account_tree_data: str,
  193. ) -> None:
  194. """在 messages 的 content 中用 replace 替换 prompt 占位符。"""
  195. post_point_count_str = str(post_point_count)
  196. for m in messages:
  197. content = m.get("content")
  198. if isinstance(content, str):
  199. m["content"] = (
  200. content.replace("{account_name}", account_name)
  201. .replace("{帖子ID}", post_id)
  202. .replace("{log_id}", log_id)
  203. .replace("{post_point_count}", post_point_count_str)
  204. .replace("{account_tree_data}", account_tree_data)
  205. )
  206. elif isinstance(content, list):
  207. for part in content:
  208. if isinstance(part, dict) and part.get("type") == "text":
  209. part["text"] = (
  210. (part.get("text") or "")
  211. .replace("{account_name}", account_name)
  212. .replace("{帖子ID}", post_id)
  213. .replace("{log_id}", log_id)
  214. .replace("{post_point_count}", post_point_count_str)
  215. .replace("{account_tree_data}", account_tree_data)
  216. )
  217. async def main(account_name, post_id):
  218. parser = argparse.ArgumentParser(description="选题点整体推导 Agent(支持交互与恢复)")
  219. parser.add_argument(
  220. "--trace", type=str, default=None,
  221. help="已有的 Trace ID,用于恢复继续执行(不指定则新建)",
  222. )
  223. args = parser.parse_args()
  224. base_dir = Path(__file__).parent
  225. prompt_path = base_dir / "derivation_main.md"
  226. # 加载项目级 presets(evaluate_derivation、derivation_search 等)
  227. presets_path = base_dir / "presets.json"
  228. if presets_path.exists():
  229. import json
  230. with open(presets_path, "r", encoding="utf-8") as f:
  231. project_presets = json.load(f)
  232. for name, cfg in project_presets.items():
  233. register_preset(name, AgentPreset(**cfg))
  234. print(f" - 已加载项目 presets: {list(project_presets.keys())}")
  235. # 注册选题点推导专用工具(主 agent 与评估子 agent 会调用)
  236. import importlib.util
  237. tools_dir = base_dir / "tools"
  238. for mod_name, file_name in [
  239. ("find_tree_node", "find_tree_node.py"),
  240. ("find_pattern", "find_pattern.py"),
  241. ("point_match", "point_match.py"),
  242. ("search_and_eval", "search_and_eval.py"),
  243. ("pattern_dimension_analyze", "pattern_dimension_analyze.py"),
  244. ]:
  245. path = tools_dir / file_name
  246. if path.is_file():
  247. spec = importlib.util.spec_from_file_location(f"overall_derivation.{mod_name}", path)
  248. mod = importlib.util.module_from_spec(spec)
  249. spec.loader.exec_module(mod)
  250. print(f" - 已注册推导工具: {mod_name}")
  251. skills_dir = str(base_dir / "skills")
  252. print("=" * 60)
  253. print("选题点整体推导 Agent(交互增强)")
  254. print("=" * 60)
  255. print()
  256. print("💡 交互提示:")
  257. print(" - 执行过程中输入 'p' 或 'pause' 暂停并进入交互模式")
  258. print(" - 执行过程中输入 'q' 或 'quit' 停止执行")
  259. print("=" * 60)
  260. print()
  261. # 在读取 prompt 前生成 log_id(格式 yyyyMMddHHmmss),保证每次运行使用同一 log_id(用于推导日志输出路径)
  262. log_id = datetime.now().strftime("%Y%m%d%H%M%S")
  263. print(f" - 本次运行 log_id: {log_id}")
  264. print(f" - account_name: {account_name}")
  265. print(f" - post_id: {post_id}")
  266. # 读取选题点列表,得到 post_point_count(用于 prompt 占位符)
  267. input_dir = base_dir / "input" / account_name / "post_topic"
  268. post_topic_path = input_dir / f"{post_id}.json"
  269. post_point_count = 0
  270. if post_topic_path.exists():
  271. import json
  272. with open(post_topic_path, "r", encoding="utf-8") as f:
  273. post_topics = json.load(f)
  274. post_point_count = len(post_topics) if isinstance(post_topics, list) else 0
  275. print(f" - 选题点数量 post_point_count: {post_point_count} (来自 {post_topic_path.relative_to(base_dir)})")
  276. else:
  277. print(f" - 未找到选题点文件: {post_topic_path},post_point_count 使用 0")
  278. # 读取账号树数据,作为 {account_tree_data} prompt 占位符
  279. simple_tree_path = base_dir / "input" / account_name / "处理后数据" / "simple_tree" / "simple_tree.txt"
  280. account_tree_data = ""
  281. if simple_tree_path.exists():
  282. with open(simple_tree_path, "r", encoding="utf-8") as f:
  283. account_tree_data = f.read()
  284. print(f" - 已读取账号树数据: {simple_tree_path.relative_to(base_dir)}")
  285. else:
  286. print(f" - 未找到账号树数据文件: {simple_tree_path},{account_tree_data=} 将使用空字符串")
  287. print("1. 加载 prompt 配置...")
  288. prompt = SimplePrompt(prompt_path)
  289. print("2. 构建任务消息...")
  290. messages = prompt.build_messages()
  291. _replace_prompt_placeholders(
  292. messages,
  293. account_name,
  294. post_id,
  295. log_id,
  296. post_point_count,
  297. account_tree_data,
  298. )
  299. print("3. 创建 Agent Runner...")
  300. print(f" - Skills 目录: {skills_dir}")
  301. model_key = prompt.config.get("model", "google/gemini-3-flash-preview")
  302. # model_id = f"google/{model_key}" if not model_key.startswith("google/") else model_key
  303. model_id = model_key
  304. print(f" - 模型: {model_id}")
  305. store = FileSystemTraceStore(base_path=str(PROJECT_ROOT / ".trace"))
  306. runner = AgentRunner(
  307. trace_store=store,
  308. llm_call=create_openrouter_llm_call(model=model_id),
  309. skills_dir=skills_dir,
  310. # experiences_path="./.cache/experiences_overall_derivation.md",
  311. debug=True,
  312. )
  313. resume_trace_id = args.trace
  314. if resume_trace_id:
  315. existing_trace = await store.get_trace(resume_trace_id)
  316. if not existing_trace:
  317. print(f"\n错误: Trace 不存在: {resume_trace_id}")
  318. sys.exit(1)
  319. print(f"4. 恢复已有 Trace: {resume_trace_id[:8]}...")
  320. print(f" - 状态: {existing_trace.status}")
  321. print(f" - 消息数: {existing_trace.total_messages}")
  322. print(f" - 任务: {existing_trace.task}")
  323. else:
  324. print("4. 启动新 Agent 模式...")
  325. print()
  326. final_response = ""
  327. current_trace_id = resume_trace_id
  328. current_sequence = 0
  329. should_exit = False
  330. try:
  331. config = RUN_CONFIG
  332. if resume_trace_id:
  333. initial_messages = None
  334. config.trace_id = resume_trace_id
  335. else:
  336. initial_messages = messages
  337. while not should_exit:
  338. if current_trace_id:
  339. config.trace_id = current_trace_id
  340. final_response = ""
  341. if current_trace_id and initial_messages is None:
  342. check_trace = await store.get_trace(current_trace_id)
  343. if check_trace and check_trace.status in ("completed", "failed"):
  344. if check_trace.status == "completed":
  345. print("\n[Trace] ✅ 已完成")
  346. print(f" - Total messages: {check_trace.total_messages}")
  347. print(f" - Total cost: ${check_trace.total_cost:.4f}")
  348. else:
  349. print(f"\n[Trace] ❌ 已失败: {check_trace.error_message}")
  350. current_sequence = check_trace.head_sequence
  351. menu_result = await show_interactive_menu(
  352. runner, current_trace_id, current_sequence, store
  353. )
  354. if menu_result["action"] == "stop":
  355. break
  356. elif menu_result["action"] == "continue":
  357. new_messages = menu_result.get("messages", [])
  358. if new_messages:
  359. initial_messages = new_messages
  360. config.after_sequence = menu_result.get("after_sequence")
  361. else:
  362. initial_messages = []
  363. config.after_sequence = None
  364. continue
  365. break
  366. initial_messages = []
  367. print(f"{'▶️ 开始执行...' if not current_trace_id else '▶️ 继续执行...'}")
  368. paused = False
  369. try:
  370. async for item in runner.run(messages=initial_messages, config=config):
  371. cmd = check_stdin()
  372. if cmd == 'pause':
  373. print("\n⏸️ 正在暂停执行...")
  374. if current_trace_id:
  375. await runner.stop(current_trace_id)
  376. await asyncio.sleep(0.5)
  377. menu_result = await show_interactive_menu(
  378. runner, current_trace_id, current_sequence, store
  379. )
  380. if menu_result["action"] == "stop":
  381. should_exit = True
  382. paused = True
  383. break
  384. elif menu_result["action"] == "continue":
  385. new_messages = menu_result.get("messages", [])
  386. if new_messages:
  387. initial_messages = new_messages
  388. after_seq = menu_result.get("after_sequence")
  389. if after_seq is not None:
  390. config.after_sequence = after_seq
  391. paused = True
  392. break
  393. else:
  394. initial_messages = []
  395. config.after_sequence = None
  396. paused = True
  397. break
  398. elif cmd == 'quit':
  399. print("\n🛑 用户请求停止...")
  400. if current_trace_id:
  401. await runner.stop(current_trace_id)
  402. should_exit = True
  403. break
  404. if isinstance(item, Trace):
  405. current_trace_id = item.trace_id
  406. if item.status == "running":
  407. print(f"[Trace] 开始: {item.trace_id[:100]}...")
  408. elif item.status == "completed":
  409. print("\n[Trace] ✅ 完成")
  410. print(f" - Total messages: {item.total_messages}")
  411. print(f" - Total tokens: {item.total_tokens}")
  412. print(f" - Total cost: ${item.total_cost:.4f}")
  413. elif item.status == "failed":
  414. print(f"\n[Trace] ❌ 失败: {item.error_message}")
  415. elif item.status == "stopped":
  416. print("\n[Trace] ⏸️ 已停止")
  417. elif isinstance(item, Message):
  418. current_sequence = item.sequence
  419. if item.role == "assistant":
  420. content = item.content
  421. if isinstance(content, dict):
  422. text = content.get("text", "")
  423. tool_calls = content.get("tool_calls")
  424. if text and not tool_calls:
  425. final_response = text
  426. print("\n[Response] Agent 回复:")
  427. print(text)
  428. elif text:
  429. preview = text[:500] + "..." if len(text) > 500 else text
  430. print(f"[Assistant] {preview}")
  431. if tool_calls:
  432. for tc in tool_calls:
  433. tool_name = tc.get("function", {}).get("name", "unknown")
  434. tool_args = tc.get("function", {}).get("arguments", "")
  435. print(f"[Tool Call] 🛠️ {tool_name}")
  436. print(f" params: {tool_args}")
  437. elif item.role == "tool":
  438. content = item.content
  439. if isinstance(content, dict):
  440. tool_name = content.get("tool_name", "unknown")
  441. tool_result = content.get("result", content)
  442. print(f"[Tool Result] ✅ {tool_name}")
  443. print(f" result: {tool_result}")
  444. if item.description:
  445. desc = item.description[:500] if len(item.description) > 500 else item.description
  446. print(f" {desc}...")
  447. except Exception as e:
  448. print(f"\n执行出错: {e}")
  449. import traceback
  450. traceback.print_exc()
  451. if paused:
  452. if should_exit:
  453. break
  454. continue
  455. if should_exit:
  456. break
  457. if current_trace_id:
  458. menu_result = await show_interactive_menu(
  459. runner, current_trace_id, current_sequence, store
  460. )
  461. if menu_result["action"] == "stop":
  462. break
  463. elif menu_result["action"] == "continue":
  464. new_messages = menu_result.get("messages", [])
  465. if new_messages:
  466. initial_messages = new_messages
  467. config.after_sequence = menu_result.get("after_sequence")
  468. else:
  469. initial_messages = []
  470. config.after_sequence = None
  471. continue
  472. break
  473. except KeyboardInterrupt:
  474. print("\n\n用户中断 (Ctrl+C)")
  475. if current_trace_id:
  476. await runner.stop(current_trace_id)
  477. if final_response:
  478. print()
  479. print("=" * 60)
  480. print("Agent 响应:")
  481. print("=" * 60)
  482. print(final_response)
  483. print("=" * 60)
  484. print()
  485. # 以脚本所在目录为基准,兼容从项目根或脚本目录启动
  486. script_dir = Path(__file__).resolve().parent
  487. output_dir = script_dir / "output"
  488. output_file = output_dir / account_name / "推导日志" / post_id / log_id / "result.txt"
  489. output_file.parent.mkdir(parents=True, exist_ok=True)
  490. with open(output_file, 'w', encoding='utf-8') as f:
  491. f.write(final_response)
  492. print(f"✓ 结果已保存到: {output_file}")
  493. print()
  494. if current_trace_id:
  495. print("=" * 60)
  496. print("可视化 Step Tree:")
  497. print("=" * 60)
  498. print("1. 启动 API Server:")
  499. print(" python3 api_server.py")
  500. print()
  501. print("2. 浏览器访问:")
  502. print(" http://localhost:8000/api/traces")
  503. print()
  504. print(f"3. Trace ID: {current_trace_id}")
  505. print(f"4. Log ID(推导日志目录): {log_id}")
  506. print("=" * 60)
  507. if __name__ == "__main__":
  508. # anthropic/claude-sonnet-4.6
  509. # google/gemini-3-flash-preview
  510. asyncio.run(main(account_name="家有大志", post_id="68fb6a5c000000000302e5de"))