run.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. """
  2. 内容树需求归纳 Agent
  3. 从内容树节点归纳制作需求(任务1)。
  4. 用法:
  5. python run.py # 直接运行(任务在 analyst.prompt 中配置)
  6. python run.py --trace <TRACE_ID> # 恢复已有 trace
  7. """
  8. import argparse
  9. import os
  10. import sys
  11. import asyncio
  12. from pathlib import Path
  13. os.environ.setdefault("no_proxy", "*")
  14. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  15. from dotenv import load_dotenv
  16. load_dotenv()
  17. from agent.llm.prompts import SimplePrompt
  18. from agent.core.runner import AgentRunner, RunConfig
  19. from agent.trace import FileSystemTraceStore, Trace, Message
  20. from agent.llm import create_qwen_llm_call
  21. from agent.cli import InteractiveController
  22. from agent.utils import setup_logging
  23. # 注册自定义工具
  24. from tools.content_tree import search_content_tree, get_category_tree
  25. from tools.frequent_itemsets import get_frequent_itemsets
  26. from config import RUN_CONFIG, SKILLS_DIR, TRACE_STORE_PATH, DEBUG, LOG_LEVEL, LOG_FILE, OUTPUT_DIR
  27. async def main():
  28. parser = argparse.ArgumentParser(description="内容树需求归纳 Agent")
  29. parser.add_argument("--trace", type=str, default=None, help="已有 Trace ID,用于恢复继续执行")
  30. args = parser.parse_args()
  31. base_dir = Path(__file__).parent
  32. project_root = base_dir.parent.parent
  33. output_dir = project_root / OUTPUT_DIR
  34. output_dir.mkdir(parents=True, exist_ok=True)
  35. setup_logging(level=LOG_LEVEL, file=LOG_FILE)
  36. # 加载 presets
  37. presets_path = base_dir / "presets.json"
  38. if presets_path.exists():
  39. from agent.core.presets import load_presets_from_json
  40. load_presets_from_json(str(presets_path))
  41. print("已加载 presets")
  42. # 构建任务消息(直接从 analyst.prompt 加载,无变量替换)
  43. if not args.trace:
  44. prompt = SimplePrompt(base_dir / "analyst.prompt")
  45. messages = prompt.build_messages(output_dir=OUTPUT_DIR)
  46. else:
  47. messages = None
  48. # 创建 Runner
  49. store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
  50. runner = AgentRunner(
  51. trace_store=store,
  52. llm_call=create_qwen_llm_call(model=RUN_CONFIG.model),
  53. skills_dir=SKILLS_DIR,
  54. debug=DEBUG,
  55. )
  56. interactive = InteractiveController(runner=runner, store=store, enable_stdin_check=True)
  57. runner.stdin_check = interactive.check_stdin
  58. print("=" * 60)
  59. print("内容树需求归纳 Agent")
  60. print("=" * 60)
  61. print("💡 输入 'p' 暂停,'q' 退出")
  62. print("=" * 60)
  63. run_config = RUN_CONFIG
  64. resume_trace_id = args.trace
  65. current_trace_id = resume_trace_id
  66. current_sequence = 0
  67. should_exit = False
  68. final_response = ""
  69. try:
  70. if resume_trace_id:
  71. existing = await store.get_trace(resume_trace_id)
  72. if not existing:
  73. print(f"错误: Trace 不存在: {resume_trace_id}")
  74. sys.exit(1)
  75. run_config.trace_id = resume_trace_id
  76. print(f"恢复 Trace: {resume_trace_id[:8]}...")
  77. while not should_exit:
  78. if current_trace_id:
  79. run_config.trace_id = current_trace_id
  80. # 恢复模式:先进交互菜单
  81. if current_trace_id and messages is None:
  82. check_trace = await store.get_trace(current_trace_id)
  83. if check_trace:
  84. current_sequence = check_trace.head_sequence
  85. menu_result = await interactive.show_menu(current_trace_id, current_sequence)
  86. if menu_result["action"] == "stop":
  87. break
  88. elif menu_result["action"] == "continue":
  89. new_msgs = menu_result.get("messages", [])
  90. messages = new_msgs if new_msgs else []
  91. run_config.after_sequence = menu_result.get("after_sequence")
  92. continue
  93. break
  94. if messages is None:
  95. messages = []
  96. print("▶️ 开始执行...")
  97. paused = False
  98. try:
  99. async for item in runner.run(messages=messages, config=run_config):
  100. cmd = interactive.check_stdin()
  101. if cmd == "pause":
  102. print("\n⏸️ 暂停中...")
  103. if current_trace_id:
  104. await runner.stop(current_trace_id)
  105. await asyncio.sleep(0.5)
  106. menu_result = await interactive.show_menu(current_trace_id, current_sequence)
  107. if menu_result["action"] == "stop":
  108. should_exit = True
  109. paused = True
  110. break
  111. elif menu_result["action"] == "continue":
  112. new_msgs = menu_result.get("messages", [])
  113. messages = new_msgs if new_msgs else []
  114. run_config.after_sequence = menu_result.get("after_sequence")
  115. paused = True
  116. break
  117. elif cmd == "quit":
  118. print("\n🛑 停止...")
  119. if current_trace_id:
  120. await runner.stop(current_trace_id)
  121. should_exit = True
  122. break
  123. if isinstance(item, Trace):
  124. current_trace_id = item.trace_id
  125. if item.status == "running":
  126. print(f"[Trace] 开始: {item.trace_id[:8]}...")
  127. elif item.status == "completed":
  128. print(f"\n[Trace] ✅ 完成 | messages={item.total_messages} | cost=${item.total_cost:.4f}")
  129. elif item.status == "failed":
  130. print(f"\n[Trace] ❌ 失败: {item.error_message}")
  131. elif isinstance(item, Message):
  132. current_sequence = item.sequence
  133. if item.role == "assistant":
  134. content = item.content
  135. if isinstance(content, dict):
  136. text = content.get("text", "")
  137. tool_calls = content.get("tool_calls")
  138. if text and not tool_calls:
  139. final_response = text
  140. print(f"\n[Response]\n{text}")
  141. elif text:
  142. preview = text[:150] + "..." if len(text) > 150 else text
  143. print(f"[Assistant] {preview}")
  144. elif item.role == "tool":
  145. content = item.content
  146. tool_name = content.get("tool_name", "unknown") if isinstance(content, dict) else "unknown"
  147. desc = item.description or ""
  148. if desc and desc != tool_name:
  149. print(f"[Tool] ✅ {tool_name}: {desc[:80]}")
  150. else:
  151. print(f"[Tool] ✅ {tool_name}")
  152. except Exception as e:
  153. print(f"\n执行出错: {e}")
  154. import traceback
  155. traceback.print_exc()
  156. if paused:
  157. if should_exit:
  158. break
  159. continue
  160. if should_exit:
  161. break
  162. if current_trace_id:
  163. menu_result = await interactive.show_menu(current_trace_id, current_sequence)
  164. if menu_result["action"] == "stop":
  165. break
  166. elif menu_result["action"] == "continue":
  167. new_msgs = menu_result.get("messages", [])
  168. messages = new_msgs if new_msgs else []
  169. run_config.after_sequence = menu_result.get("after_sequence")
  170. continue
  171. break
  172. except KeyboardInterrupt:
  173. print("\n用户中断 (Ctrl+C)")
  174. if current_trace_id:
  175. await runner.stop(current_trace_id)
  176. # 保存最终结果
  177. if final_response:
  178. result_file = output_dir / "result.txt"
  179. result_file.write_text(final_response, encoding="utf-8")
  180. print(f"\n✓ 结果已保存: {result_file}")
  181. if current_trace_id:
  182. print(f"\nTrace ID: {current_trace_id}")
  183. print("可视化: python3 api_server.py → http://localhost:8000/api/traces")
  184. if __name__ == "__main__":
  185. asyncio.run(main())