| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- """
- 内容树需求归纳 Agent
- 从内容树节点归纳制作需求(任务1)。
- 用法:
- python run.py # 直接运行(任务在 analyst.prompt 中配置)
- python run.py --trace <TRACE_ID> # 恢复已有 trace
- """
import argparse
import os
import sys
import asyncio
from pathlib import Path

# Default no_proxy to "*" (conventionally "bypass proxies for every host" for
# HTTP clients that honour it) unless the environment already sets it.
os.environ.setdefault("no_proxy", "*")
# Put the directory two levels above this script's folder (the same path
# main() calls project_root) at the front of sys.path; presumably this is what
# lets the `agent` package imports below resolve — confirm.
# Must run before those imports.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from dotenv import load_dotenv

# Load variables from a .env file into os.environ. This runs before the
# project imports below, presumably so `config` can read them — confirm.
load_dotenv()

from agent.llm.prompts import SimplePrompt
from agent.core.runner import AgentRunner, RunConfig
from agent.trace import FileSystemTraceStore, Trace, Message
from agent.llm import create_qwen_llm_call
from agent.cli import InteractiveController
from agent.utils import setup_logging

# Register custom tools. These names are not referenced in the code visible
# here; the imports are presumably kept for the side effect of registering the
# tools with the agent framework — TODO confirm before removing.
from tools.content_tree import search_content_tree, get_category_tree
from tools.frequent_itemsets import get_frequent_itemsets
from config import RUN_CONFIG, SKILLS_DIR, TRACE_STORE_PATH, DEBUG, LOG_LEVEL, LOG_FILE, OUTPUT_DIR
def _continue_messages(menu_result, run_config):
    """Apply a "continue" choice returned by the interactive menu.

    Copies the menu's ``after_sequence`` onto *run_config* and returns the
    messages to send on the next run (an empty list when the menu supplied
    none or an empty list).
    """
    run_config.after_sequence = menu_result.get("after_sequence")
    return menu_result.get("messages") or []


def _print_trace(trace):
    """Print a status line for a ``Trace`` item streamed by the runner."""
    if trace.status == "running":
        print(f"[Trace] 开始: {trace.trace_id[:8]}...")
    elif trace.status == "completed":
        print(f"\n[Trace] ✅ 完成 | messages={trace.total_messages} | cost=${trace.total_cost:.4f}")
    elif trace.status == "failed":
        print(f"\n[Trace] ❌ 失败: {trace.error_message}")


def _print_message(message):
    """Print a progress line for a ``Message`` item streamed by the runner.

    Returns:
        The assistant text when *message* is an assistant reply whose dict
        content has non-empty ``text`` and no ``tool_calls`` (a candidate
        final answer); otherwise ``None``.
    """
    if message.role == "assistant":
        content = message.content
        if isinstance(content, dict):
            text = content.get("text", "")
            tool_calls = content.get("tool_calls")
            if text and not tool_calls:
                print(f"\n[Response]\n{text}")
                return text
            if text:
                preview = text[:150] + "..." if len(text) > 150 else text
                print(f"[Assistant] {preview}")
    elif message.role == "tool":
        content = message.content
        tool_name = content.get("tool_name", "unknown") if isinstance(content, dict) else "unknown"
        desc = message.description or ""
        if desc and desc != tool_name:
            print(f"[Tool] ✅ {tool_name}: {desc[:80]}")
        else:
            print(f"[Tool] ✅ {tool_name}")
    return None


async def _close_stream(stream):
    """Best-effort finalization of a runner stream we may have left mid-iteration.

    Breaking out of ``async for`` leaves an async generator suspended; its
    cleanup code would otherwise run only when it is garbage-collected or when
    the event loop shuts down. Closing it here runs that cleanup before the next
    ``runner.run`` call. ``None`` and iterators without ``aclose`` are ignored.
    """
    aclose = getattr(stream, "aclose", None)
    if aclose is None:
        return
    try:
        await aclose()
    except Exception as exc:  # cleanup must not mask the run's own outcome
        print(f"\n关闭执行流时出错: {exc}")


async def main():
    """Run the content-tree requirement-induction agent from the command line.

    Without arguments, starts a new run whose task messages come from
    ``analyst.prompt``. With ``--trace <ID>``, resumes that trace by opening
    the interactive menu first.

    Runner output is streamed to stdout, and the user can pause ('p') or quit
    ('q') via ``InteractiveController``. When the session ends for any reason
    (normal exit, error, or Ctrl+C), the last tool-call-free assistant reply is
    written to ``<project_root>/<OUTPUT_DIR>/result.txt``.

    Exits with status 1 if ``--trace`` names a trace that does not exist.
    """
    import traceback

    parser = argparse.ArgumentParser(description="内容树需求归纳 Agent")
    parser.add_argument("--trace", type=str, default=None, help="已有 Trace ID,用于恢复继续执行")
    args = parser.parse_args()

    base_dir = Path(__file__).parent
    project_root = base_dir.parent.parent
    output_dir = project_root / OUTPUT_DIR
    output_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(level=LOG_LEVEL, file=LOG_FILE)

    # Load presets when a presets.json sits next to this script.
    presets_path = base_dir / "presets.json"
    if presets_path.exists():
        from agent.core.presets import load_presets_from_json
        load_presets_from_json(str(presets_path))
        print("已加载 presets")

    # Fresh run: build the task messages straight from analyst.prompt.
    # Resume run: leave messages as None so the loop shows the menu first.
    if not args.trace:
        prompt = SimplePrompt(base_dir / "analyst.prompt")
        messages = prompt.build_messages(output_dir=OUTPUT_DIR)
    else:
        messages = None

    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_qwen_llm_call(model=RUN_CONFIG.model),
        skills_dir=SKILLS_DIR,
        debug=DEBUG,
    )
    interactive = InteractiveController(runner=runner, store=store, enable_stdin_check=True)
    runner.stdin_check = interactive.check_stdin

    print("=" * 60)
    print("内容树需求归纳 Agent")
    print("=" * 60)
    print("💡 输入 'p' 暂停,'q' 退出")
    print("=" * 60)

    # NOTE: run_config is the shared RUN_CONFIG object from `config`, and this
    # function mutates it (trace_id, after_sequence). That is fine for a
    # one-shot script.
    run_config = RUN_CONFIG
    resume_trace_id = args.trace
    current_trace_id = resume_trace_id
    current_sequence = 0
    should_exit = False
    final_response = ""

    # Validate the resume target before the session starts, so an unknown ID
    # exits without running the end-of-session summary in `finally` below.
    if resume_trace_id:
        existing = await store.get_trace(resume_trace_id)
        if not existing:
            print(f"错误: Trace 不存在: {resume_trace_id}")
            sys.exit(1)
        run_config.trace_id = resume_trace_id
        print(f"恢复 Trace: {resume_trace_id[:8]}...")

    try:
        while not should_exit:
            if current_trace_id:
                run_config.trace_id = current_trace_id

            # Resume mode: show the interactive menu before the first run.
            if current_trace_id and messages is None:
                check_trace = await store.get_trace(current_trace_id)
                if check_trace:
                    current_sequence = check_trace.head_sequence
                menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                if menu_result["action"] == "continue":
                    messages = _continue_messages(menu_result, run_config)
                    continue
                break  # "stop" or any unrecognised action ends the session

            if messages is None:
                messages = []

            print("▶️ 开始执行...")
            paused = False
            stream = None
            try:
                stream = runner.run(messages=messages, config=run_config)
                async for item in stream:
                    cmd = interactive.check_stdin()
                    if cmd == "pause":
                        print("\n⏸️ 暂停中...")
                        if current_trace_id:
                            await runner.stop(current_trace_id)
                        # Presumably gives the runner a moment to wind down.
                        await asyncio.sleep(0.5)
                        menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                        if menu_result["action"] == "stop":
                            should_exit = True
                            paused = True
                            break
                        if menu_result["action"] == "continue":
                            messages = _continue_messages(menu_result, run_config)
                            paused = True
                            break
                        # Any other menu action: keep consuming the current run.
                    elif cmd == "quit":
                        print("\n🛑 停止...")
                        if current_trace_id:
                            await runner.stop(current_trace_id)
                        should_exit = True
                        break

                    if isinstance(item, Trace):
                        current_trace_id = item.trace_id
                        _print_trace(item)
                    elif isinstance(item, Message):
                        current_sequence = item.sequence
                        reply = _print_message(item)
                        if reply is not None:
                            final_response = reply
            except Exception as e:
                print(f"\n执行出错: {e}")
                traceback.print_exc()
            finally:
                # Deterministically finalize a stream abandoned via `break`
                # before the next runner.run() touches the same trace.
                await _close_stream(stream)

            if should_exit:
                break
            if paused:
                continue  # the pause menu already chose the next messages

            # The run finished on its own: ask what to do next.
            if not current_trace_id:
                break
            menu_result = await interactive.show_menu(current_trace_id, current_sequence)
            if menu_result["action"] == "continue":
                messages = _continue_messages(menu_result, run_config)
                continue
            break
    except (KeyboardInterrupt, asyncio.CancelledError):
        # On Python 3.11+, asyncio.run() turns the first Ctrl+C into a
        # cancellation of this task, so CancelledError has to be handled too.
        print("\n用户中断 (Ctrl+C)")
        if current_trace_id:
            await runner.stop(current_trace_id)
    finally:
        # Save the final result on every exit path, including errors and Ctrl+C.
        if final_response:
            result_file = output_dir / "result.txt"
            result_file.write_text(final_response, encoding="utf-8")
            print(f"\n✓ 结果已保存: {result_file}")
        if current_trace_id:
            print(f"\nTrace ID: {current_trace_id}")
            print("可视化: python3 api_server.py → http://localhost:8000/api/traces")
if __name__ == "__main__":
    # Script entry point: build the top-level coroutine and drive it to
    # completion on a fresh event loop.
    entry_coro = main()
    asyncio.run(entry_coro)
|