# run.py
"""
Example (flow-aligned version).

Modeled on examples/research/run.py:
1. Use the framework's InteractiveController to unify the interaction flow
2. Use config.py to manage run parameters
3. Keep the prompt injection and detailed message printing that are specific
   to the create scenario
"""
import argparse
import asyncio
import copy
import json
import os
import sys
from pathlib import Path
from typing import Any
import logging

# Clash Verge TUN-mode compatibility: stop httpx/urllib from auto-detecting
# the system HTTP proxy. setdefault() leaves an already-set no_proxy untouched.
os.environ.setdefault("no_proxy", "*")

logger = logging.getLogger(__name__)

# Add the project root (three directory levels above this file) to the Python
# path; this runs before the `agent` / `examples` imports below.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from dotenv import load_dotenv

# Called before the project imports below.
# NOTE(review): presumably so that values from .env are visible to them — confirm.
load_dotenv()

from agent.cli import InteractiveController
from agent.core.presets import AgentPreset, register_preset
from agent.core.runner import AgentRunner
from agent.llm import create_openrouter_llm_call
from agent.llm.prompts import SimplePrompt
from agent.trace import FileSystemTraceStore, Message, Trace
from agent.utils import setup_logging
from examples.create.html import trace_to_html

# Import project configuration
from config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, SKILLS_DIR, TRACE_STORE_PATH
  34. def _format_json(obj: Any, indent: int = 2) -> str:
  35. """格式化 JSON 对象为字符串"""
  36. try:
  37. return json.dumps(obj, indent=indent, ensure_ascii=False)
  38. except (TypeError, ValueError):
  39. return str(obj)
  40. def _print_message_details(message: Message):
  41. """完整打印消息的详细信息"""
  42. logger.info("\n" + "=" * 80)
  43. logger.info(f"[Message #{message.sequence}] {message.role.upper()}")
  44. logger.info("=" * 80)
  45. if message.goal_id:
  46. logger.info(f"Goal ID: {message.goal_id}")
  47. if message.parent_sequence is not None:
  48. logger.info(f"Parent Sequence: {message.parent_sequence}")
  49. if message.tool_call_id:
  50. logger.info(f"Tool Call ID: {message.tool_call_id}")
  51. if message.role == "user":
  52. logger.info("\n[输入内容]")
  53. logger.info("-" * 80)
  54. if isinstance(message.content, str):
  55. logger.info(message.content)
  56. else:
  57. logger.info(_format_json(message.content))
  58. elif message.role == "assistant":
  59. content = message.content
  60. if isinstance(content, dict):
  61. text = content.get("text", "")
  62. tool_calls = content.get("tool_calls")
  63. if text:
  64. logger.info("\n[LLM 文本回复]")
  65. logger.info("-" * 80)
  66. logger.info(text)
  67. if tool_calls:
  68. logger.info(f"\n[工具调用] (共 {len(tool_calls)} 个)")
  69. logger.info("-" * 80)
  70. for idx, tc in enumerate(tool_calls, 1):
  71. func = tc.get("function", {})
  72. tool_name = func.get("name", "unknown")
  73. tool_id = tc.get("id", "unknown")
  74. arguments = func.get("arguments", {})
  75. logger.info(f"\n工具 #{idx}: {tool_name}")
  76. logger.info(f" Call ID: {tool_id}")
  77. logger.info(" 参数:")
  78. if isinstance(arguments, str):
  79. try:
  80. parsed_args = json.loads(arguments)
  81. logger.info(_format_json(parsed_args, indent=4))
  82. except json.JSONDecodeError:
  83. logger.info(f" {arguments}")
  84. else:
  85. logger.info(_format_json(arguments, indent=4))
  86. elif isinstance(content, str):
  87. logger.info("\n[LLM 文本回复]")
  88. logger.info("-" * 80)
  89. logger.info(content)
  90. else:
  91. logger.info("\n[内容]")
  92. logger.info("-" * 80)
  93. logger.info(_format_json(content))
  94. if message.finish_reason:
  95. logger.info(f"\n完成原因: {message.finish_reason}")
  96. elif message.role == "tool":
  97. content = message.content
  98. logger.info("\n[工具执行结果]")
  99. logger.info("-" * 80)
  100. if isinstance(content, dict):
  101. tool_name = content.get("tool_name", "unknown")
  102. result = content.get("result", content)
  103. logger.info(f"工具名称: {tool_name}")
  104. logger.info("\n返回结果:")
  105. if isinstance(result, str):
  106. logger.info(result)
  107. elif isinstance(result, list):
  108. for idx, item in enumerate(result, 1):
  109. if isinstance(item, dict) and item.get("type") == "image_url":
  110. logger.info(f" [{idx}] 图片 (base64, 已省略显示)")
  111. else:
  112. logger.info(f" [{idx}] {item}")
  113. else:
  114. logger.info(_format_json(result))
  115. else:
  116. logger.info(str(content) if content is not None else "(无内容)")
  117. elif message.role == "system":
  118. logger.info("\n[系统提示]")
  119. logger.info("-" * 80)
  120. if isinstance(message.content, str):
  121. logger.info(message.content)
  122. else:
  123. logger.info(_format_json(message.content))
  124. if message.prompt_tokens is not None or message.completion_tokens is not None:
  125. logger.info("\n[Token 使用]")
  126. logger.info("-" * 80)
  127. if message.prompt_tokens is not None:
  128. logger.info(f" 输入 Tokens: {message.prompt_tokens:,}")
  129. if message.completion_tokens is not None:
  130. logger.info(f" 输出 Tokens: {message.completion_tokens:,}")
  131. if message.reasoning_tokens is not None:
  132. logger.info(f" 推理 Tokens: {message.reasoning_tokens:,}")
  133. if message.cache_creation_tokens is not None:
  134. logger.info(f" 缓存创建 Tokens: {message.cache_creation_tokens:,}")
  135. if message.cache_read_tokens is not None:
  136. logger.info(f" 缓存读取 Tokens: {message.cache_read_tokens:,}")
  137. if message.tokens:
  138. logger.info(f" 总计 Tokens: {message.tokens:,}")
  139. if message.cost is not None:
  140. logger.info(f"\n[成本] ${message.cost:.6f}")
  141. if message.duration_ms is not None:
  142. logger.info(f"[执行时间] {message.duration_ms}ms")
  143. logger.info("=" * 80 + "\n")
  144. def _apply_prompt_placeholders(base_dir: Path, prompt: SimplePrompt, persona_dir: str = None):
  145. """把 PRD 文件内容和人设树数据注入 prompt 占位符。
  146. Args:
  147. base_dir: 基础目录
  148. prompt: SimplePrompt 对象
  149. persona_dir: 人设数据目录名,如 "家有大志"。如果为 None,则不替换树数据
  150. """
  151. # 替换 {{person_name}} 占位符
  152. if persona_dir:
  153. person_name_placeholder = "{{person_name}}"
  154. if "system" in prompt._messages and person_name_placeholder in prompt._messages["system"]:
  155. prompt._messages["system"] = prompt._messages["system"].replace(person_name_placeholder, persona_dir)
  156. logger.info(f" - 已替换 {{{{person_name}}}} 为: {persona_dir}")
  157. if "user" in prompt._messages and person_name_placeholder in prompt._messages["user"]:
  158. prompt._messages["user"] = prompt._messages["user"].replace(person_name_placeholder, persona_dir)
  159. logger.info(f" - 已替换 {{{{person_name}}}} 为: {persona_dir} (user)")
  160. system_md_path = base_dir / "PRD" / "system.md"
  161. if system_md_path.exists():
  162. system_content = system_md_path.read_text(encoding="utf-8")
  163. if "system" in prompt._messages and "{system}" in prompt._messages["system"]:
  164. prompt._messages["system"] = prompt._messages["system"].replace("{system}", system_content)
  165. else:
  166. logger.warning(f" - 警告: system.md 文件不存在: {system_md_path}")
  167. # 优先使用 v2 版本,如果不存在则使用原版本
  168. create_process_md_path = base_dir / "PRD" / "create_process_v2.md"
  169. if not create_process_md_path.exists():
  170. create_process_md_path = base_dir / "PRD" / "create_process.md"
  171. if create_process_md_path.exists():
  172. create_process_content = create_process_md_path.read_text(encoding="utf-8")
  173. if "system" in prompt._messages and "{create_process}" in prompt._messages["system"]:
  174. prompt._messages["system"] = prompt._messages["system"].replace("{create_process}", create_process_content)
  175. logger.info(f" - 已替换 {create_process_md_path.name} 内容到 prompt")
  176. else:
  177. logger.warning(" - 警告: prompt 中未找到 {create_process} 占位符")
  178. else:
  179. logger.warning(f" - 警告: create_process.md 文件不存在: {create_process_md_path}")
  180. # 替换人设树数据
  181. if persona_dir:
  182. tree_dir = base_dir / "data" / persona_dir / "tree"
  183. if tree_dir.exists():
  184. # 读取三个树文件
  185. tree_files = {
  186. "形式_point_tree_how": tree_dir / "形式_point_tree_how.json",
  187. "实质_point_tree_how": tree_dir / "实质_point_tree_how.json",
  188. "意图_point_tree_how": tree_dir / "意图_point_tree_how.json"
  189. }
  190. for var_name, tree_path in tree_files.items():
  191. if tree_path.exists():
  192. tree_content = tree_path.read_text(encoding="utf-8")
  193. placeholder = "{{" + var_name + "}}"
  194. # 在 system 消息中替换
  195. if "system" in prompt._messages and placeholder in prompt._messages["system"]:
  196. prompt._messages["system"] = prompt._messages["system"].replace(placeholder, tree_content)
  197. logger.info(f" - 已替换 {var_name} 数据到 prompt")
  198. # 在 user 消息中替换
  199. if "user" in prompt._messages and placeholder in prompt._messages["user"]:
  200. prompt._messages["user"] = prompt._messages["user"].replace(placeholder, tree_content)
  201. logger.info(f" - 已替换 {var_name} 数据到 prompt (user)")
  202. else:
  203. logger.warning(f" - 警告: 树文件不存在: {tree_path}")
  204. else:
  205. logger.warning(f" - 警告: 人设树目录不存在: {tree_dir}")
  206. input_md_path = base_dir / "PRD" / "input.md"
  207. if input_md_path.exists():
  208. user_content = input_md_path.read_text(encoding="utf-8")
  209. if "user" in prompt._messages and "{input}" in prompt._messages["user"]:
  210. prompt._messages["user"] = prompt._messages["user"].replace("{input}", user_content)
  211. logger.info(" - 已替换 input.md 内容到 prompt")
  212. else:
  213. logger.warning(" - 警告: prompt 中未找到 {input} 占位符")
  214. else:
  215. logger.warning(f" - 警告: input.md 文件不存在: {input_md_path}")
  216. output_md_path = base_dir / "PRD" / "output.md"
  217. if output_md_path.exists():
  218. output_content = output_md_path.read_text(encoding="utf-8")
  219. if "user" in prompt._messages and "{output}" in prompt._messages["user"]:
  220. prompt._messages["user"] = prompt._messages["user"].replace("{output}", output_content)
  221. logger.info(" - 已替换 output.md 内容到 prompt")
  222. else:
  223. logger.warning(" - 警告: prompt 中未找到 {output} 占位符")
  224. else:
  225. logger.warning(f" - 警告: output.md 文件不存在: {output_md_path}")
async def main():
    """Run the create-scenario agent with interactive pause / resume support.

    Steps, in order:
      * parse ``--trace`` (resume an existing trace) and ``--persona``;
      * configure logging and register presets from ``presets.json`` if present;
      * load ``create.prompt`` and inject placeholder content;
      * build the AgentRunner and InteractiveController;
      * loop over ``runner.run(...)``, reacting to "pause" / "quit" typed on
        stdin and showing the interactive menu between runs;
      * finally render the trace to ``messages.html`` and, when the agent
        produced a final text reply, write it to ``output_1/result.txt``.

    NOTE(review): the nesting in this function was reconstructed from a paste
    that had lost its indentation; spots where more than one nesting was
    syntactically possible are marked below and should be confirmed against
    the original file.
    """
    parser = argparse.ArgumentParser(description="任务 (Agent 模式 + 交互增强)")
    parser.add_argument(
        "--trace",
        type=str,
        default=None,
        help="已有的 Trace ID,用于恢复继续执行(不指定则新建)",
    )
    parser.add_argument(
        "--persona",
        type=str,
        default=None,
        help="人设数据目录名,如 '家有大志'。用于读取 data/{目录名}/tree 下的树数据",
    )
    args = parser.parse_args()

    # All paths are resolved relative to the directory containing this script.
    base_dir = Path(__file__).parent
    prompt_path = base_dir / "create.prompt"
    output_dir = base_dir / "output_1"
    output_dir.mkdir(exist_ok=True)

    setup_logging(level=LOG_LEVEL, file=LOG_FILE)

    # Register project-level presets when presets.json exists next to the script.
    logger.info("2. 加载 presets...")
    presets_path = base_dir / "presets.json"
    if presets_path.exists():
        with open(presets_path, "r", encoding="utf-8") as f:
            project_presets = json.load(f)
        # Each top-level key is a preset name; its value is passed as keyword
        # arguments to AgentPreset.
        for name, cfg in project_presets.items():
            register_preset(name, AgentPreset(**cfg))
        logger.info(f" - 已加载项目 presets: {list(project_presets.keys())}")

    # Load the prompt template and fill in its placeholders.
    logger.info("3. 加载 prompt...")
    prompt = SimplePrompt(prompt_path)
    _apply_prompt_placeholders(base_dir, prompt, persona_dir=args.persona)

    # Log the prompt after substitution so the exact input can be inspected.
    logger.info("\n替换后的 prompt:")
    logger.info("=" * 60)
    logger.info("System:")
    logger.info("-" * 60)
    logger.info(prompt._messages.get("system", ""))
    logger.info("=" * 60)
    if "user" in prompt._messages:
        logger.info("\nUser:")
        logger.info("-" * 60)
        logger.info(prompt._messages["user"])
        # NOTE(review): closing rule assumed to belong to the user section — confirm.
        logger.info("=" * 60)
    logger.info("")

    logger.info("4. 构建任务消息...")
    messages = prompt.build_messages()

    logger.info("5. 创建 Agent Runner...")
    logger.info(" - 加载自定义工具: topic_search")
    # Imported only for its import-time side effects (the name is unused).
    # NOTE(review): presumably this registers the topic_search tool — confirm.
    import examples.create.tool  # noqa: F401

    # Model selection: the prompt's own "model" setting wins; otherwise use the
    # configured model, prefixed with "anthropic/" when it contains no "/".
    model_from_prompt = prompt.config.get("model")
    model_from_config = RUN_CONFIG.model
    default_model = f"anthropic/{model_from_config}" if "/" not in model_from_config else model_from_config
    model = model_from_prompt or default_model

    # A relative SKILLS_DIR is resolved against the script directory; an
    # absolute one is used unchanged.
    skills_dir = str((base_dir / SKILLS_DIR).resolve()) if not Path(SKILLS_DIR).is_absolute() else SKILLS_DIR
    logger.info(f" - Skills 目录: {skills_dir}")
    logger.info(f" - 模型: {model}")

    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_openrouter_llm_call(model=model),
        skills_dir=skills_dir,
        debug=DEBUG,
    )
    interactive = InteractiveController(
        runner=runner,
        store=store,
        enable_stdin_check=True,
    )

    task_name = RUN_CONFIG.name or base_dir.name
    logger.info("=" * 60)
    logger.info(task_name)
    logger.info("=" * 60)
    logger.info("💡 交互提示:")
    logger.info(" - 执行过程中输入 'p' 或 'pause' 暂停并进入交互模式")
    logger.info(" - 执行过程中输入 'q' 或 'quit' 停止执行")
    logger.info("=" * 60)
    logger.info("")

    # When --trace is given the trace must already exist in the store;
    # otherwise the process exits with status 1.
    resume_trace_id = args.trace
    if resume_trace_id:
        existing_trace = await store.get_trace(resume_trace_id)
        if not existing_trace:
            logger.error(f"\n错误: Trace 不存在: {resume_trace_id}")
            sys.exit(1)
        logger.info(f"恢复已有 Trace: {resume_trace_id[:8]}...")
        logger.info(f" - 状态: {existing_trace.status}")
        logger.info(f" - 消息数: {existing_trace.total_messages}")
    else:
        logger.info("启动新 Agent...")
    logger.info("")

    # Loop state shared between the run loop and the post-run reporting below.
    final_response = ""
    current_trace_id = resume_trace_id
    current_sequence = 0
    should_exit = False

    try:
        # Work on a private copy so the imported RUN_CONFIG is not mutated.
        run_config = copy.deepcopy(RUN_CONFIG)
        run_config.model = model
        # Prompt-level settings override the configured defaults.
        run_config.temperature = float(prompt.config.get("temperature", run_config.temperature))
        run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations))

        if resume_trace_id:
            # None marks "resuming, nothing sent yet"; the loop below uses it
            # to decide whether to inspect the stored trace first.
            initial_messages = None
            run_config.trace_id = resume_trace_id
        else:
            initial_messages = messages
            run_config.name = "社交媒体内容解构、建构、评估任务"

        while not should_exit:
            if current_trace_id:
                run_config.trace_id = current_trace_id
            # Reset on every pass so only the latest run's reply is reported.
            final_response = ""

            # First pass of a resumed trace: if it already finished (completed
            # or failed), go straight to the menu instead of running.
            if current_trace_id and initial_messages is None:
                check_trace = await store.get_trace(current_trace_id)
                if check_trace and check_trace.status in ("completed", "failed"):
                    if check_trace.status == "completed":
                        logger.info("\n[Trace] ✅ 已完成")
                        logger.info(f" - Total messages: {check_trace.total_messages}")
                        logger.info(f" - Total cost: ${check_trace.total_cost:.4f}")
                    else:
                        logger.error(f"\n[Trace] ❌ 已失败: {check_trace.error_message}")
                    current_sequence = check_trace.head_sequence
                    menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                    if menu_result["action"] == "stop":
                        break
                    if menu_result["action"] == "continue":
                        new_messages = menu_result.get("messages", [])
                        if new_messages:
                            initial_messages = new_messages
                            run_config.after_sequence = menu_result.get("after_sequence")
                        else:
                            initial_messages = []
                            run_config.after_sequence = None
                        continue
                    # Any other menu action ends the session.
                    break
                # Trace exists but is not finished: resume with no new messages.
                initial_messages = []

            logger.info(f"{'▶️ 开始执行...' if not current_trace_id else '▶️ 继续执行...'}")
            # Set when the stream is left because of a user pause.
            paused = False
            try:
                async for item in runner.run(messages=initial_messages, config=run_config):
                    # Check for a typed command once per streamed item.
                    cmd = interactive.check_stdin()
                    if cmd == "pause":
                        logger.info("\n⏸️ 正在暂停执行...")
                        if current_trace_id:
                            await runner.stop(current_trace_id)
                        # NOTE(review): presumably gives the runner time to
                        # act on the stop request before the menu — confirm.
                        await asyncio.sleep(0.5)
                        menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                        if menu_result["action"] == "stop":
                            should_exit = True
                            paused = True
                            break
                        if menu_result["action"] == "continue":
                            new_messages = menu_result.get("messages", [])
                            if new_messages:
                                initial_messages = new_messages
                                # Unlike the other menu handlers, after_sequence
                                # is only overwritten here when the menu gave one.
                                after_seq = menu_result.get("after_sequence")
                                if after_seq is not None:
                                    run_config.after_sequence = after_seq
                            else:
                                initial_messages = []
                                run_config.after_sequence = None
                            # NOTE(review): placement assumed to be inside the
                            # "continue" branch, after the if/else — confirm.
                            paused = True
                            break
                    elif cmd == "quit":
                        logger.info("\n🛑 用户请求停止...")
                        if current_trace_id:
                            await runner.stop(current_trace_id)
                        should_exit = True
                        break

                    if isinstance(item, Trace):
                        # Trace items carry the trace id and lifecycle status.
                        current_trace_id = item.trace_id
                        if item.status == "running":
                            logger.info(f"[Trace] 开始: {item.trace_id[:8]}...")
                        elif item.status == "completed":
                            logger.info("\n[Trace] ✅ 完成")
                            logger.info(f" - Total messages: {item.total_messages}")
                            logger.info(f" - Total tokens: {item.total_tokens}")
                            logger.info(f" - Total cost: ${item.total_cost:.4f}")
                        elif item.status == "failed":
                            logger.error(f"\n[Trace] ❌ 失败: {item.error_message}")
                        elif item.status == "stopped":
                            logger.info("\n[Trace] ⏸️ 已停止")
                    elif isinstance(item, Message):
                        current_sequence = item.sequence
                        _print_message_details(item)
                        if item.role == "assistant":
                            content = item.content
                            if isinstance(content, dict):
                                text = content.get("text", "")
                                tool_calls = content.get("tool_calls")
                                # An assistant reply with text and no tool calls
                                # is kept as the candidate final answer.
                                if text and not tool_calls:
                                    final_response = text
            except Exception as e:
                # Errors during a run are logged and fall through to the
                # handling below instead of aborting the session.
                logger.error(f"\n执行出错: {e}")
                logger.exception("Exception details:")

            if paused:
                if should_exit:
                    break
                # Restart the loop with the messages chosen in the pause menu.
                continue
            if should_exit:
                break

            # The run ended without a pause or quit: offer the menu again.
            if current_trace_id:
                menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                if menu_result["action"] == "stop":
                    break
                if menu_result["action"] == "continue":
                    new_messages = menu_result.get("messages", [])
                    if new_messages:
                        initial_messages = new_messages
                        run_config.after_sequence = menu_result.get("after_sequence")
                    else:
                        initial_messages = []
                        run_config.after_sequence = None
                    continue
            break
    except KeyboardInterrupt:
        logger.info("\n\n用户中断 (Ctrl+C)")
        if current_trace_id:
            await runner.stop(current_trace_id)
    finally:
        # Best effort: render the trace to HTML; a failure is only logged.
        if current_trace_id:
            try:
                html_path = store.base_path / current_trace_id / "messages.html"
                await trace_to_html(current_trace_id, html_path, base_path=str(store.base_path))
                logger.info(f"\n✓ Messages 可视化已保存: {html_path}")
            except Exception as e:
                logger.error(f"\n⚠ 生成 HTML 失败: {e}")

    # NOTE(review): the reporting below is assumed to follow the try/finally
    # rather than sit inside the finally clause — confirm.
    if final_response:
        logger.info("")
        logger.info("=" * 60)
        logger.info("Agent 响应:")
        logger.info("=" * 60)
        logger.info(final_response)
        logger.info("=" * 60)
        logger.info("")
        # Persist the final reply, overwriting any previous result.txt.
        output_file = output_dir / "result.txt"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(final_response)
        logger.info(f"✓ 结果已保存到: {output_file}")
        logger.info("")

    # Tell the user where the trace can be inspected.
    if current_trace_id:
        html_path = store.base_path / current_trace_id / "messages.html"
        logger.info("=" * 60)
        logger.info("可视化:")
        logger.info("=" * 60)
        logger.info(f"1. 本地 HTML: {html_path}")
        logger.info("")
        logger.info("2. API Server:")
        logger.info(" python3 api_server.py")
        logger.info(" http://localhost:8000/api/traces")
        logger.info("")
        logger.info(f"3. Trace ID: {current_trace_id}")
        logger.info("=" * 60)
# Script entry point: when the file is executed directly, run the main()
# coroutine to completion with asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())