| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- """
- 示例(流程对齐版)
- 参考 examples/research/run.py:
- 1. 使用框架 InteractiveController 统一交互流程
- 2. 使用 config.py 管理运行参数
- 3. 保留 create 场景特有的 prompt 注入与详细消息打印
- """
- import argparse
- import asyncio
- import copy
- import json
- import os
- import sys
- from pathlib import Path
- from typing import Any
- import logging
# Clash Verge TUN mode compatibility: unless the user already set no_proxy,
# exempt every host ("*") so httpx/urllib do not auto-detect a system HTTP proxy.
if "no_proxy" not in os.environ:
    os.environ["no_proxy"] = "*"

logger = logging.getLogger(__name__)

# Prepend the directory three levels above this file (presumably the project
# root) to sys.path so the `agent.*` / `examples.*` imports that follow resolve.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from dotenv import load_dotenv  # noqa: E402  (must come after the sys.path tweak)

# Load environment variables via python-dotenv before the agent modules are imported.
load_dotenv()
- from agent.cli import InteractiveController
- from agent.core.presets import AgentPreset, register_preset
- from agent.core.runner import AgentRunner
- from agent.llm import create_openrouter_llm_call
- from agent.llm.prompts import SimplePrompt
- from agent.trace import FileSystemTraceStore, Message, Trace
- from agent.utils import setup_logging
- from examples.create.html import trace_to_html
- # 导入项目配置
- from config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, SKILLS_DIR, TRACE_STORE_PATH
- def _format_json(obj: Any, indent: int = 2) -> str:
- """格式化 JSON 对象为字符串"""
- try:
- return json.dumps(obj, indent=indent, ensure_ascii=False)
- except (TypeError, ValueError):
- return str(obj)
def _print_message_details(message: Message):
    """Dump one trace message to the log as a framed, human-readable report.

    Layout:
      * header: sequence number, upper-cased role, and any goal id /
        parent sequence / tool-call id present on the message;
      * body, chosen by ``message.role``: user input, assistant text plus
        tool calls (with arguments pretty-printed), tool result (base64 images
        elided), or system prompt;
      * footers: token usage (only when prompt or completion tokens are set),
        cost, and duration.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def open_section(title: str) -> None:
        # Every body/footer section starts with a title line and a thin rule.
        logger.info(title)
        logger.info(light_rule)

    # ---- header ---------------------------------------------------------
    logger.info("\n" + heavy_rule)
    logger.info(f"[Message #{message.sequence}] {message.role.upper()}")
    logger.info(heavy_rule)
    if message.goal_id:
        logger.info(f"Goal ID: {message.goal_id}")
    if message.parent_sequence is not None:
        logger.info(f"Parent Sequence: {message.parent_sequence}")
    if message.tool_call_id:
        logger.info(f"Tool Call ID: {message.tool_call_id}")

    # ---- role-specific body ---------------------------------------------
    role = message.role
    body = message.content

    if role in ("user", "system"):
        open_section("\n[输入内容]" if role == "user" else "\n[系统提示]")
        logger.info(body if isinstance(body, str) else _format_json(body))

    elif role == "assistant":
        if isinstance(body, dict):
            reply_text = body.get("text", "")
            calls = body.get("tool_calls")
            if reply_text:
                open_section("\n[LLM 文本回复]")
                logger.info(reply_text)
            if calls:
                open_section(f"\n[工具调用] (共 {len(calls)} 个)")
                for position, call in enumerate(calls, start=1):
                    function_spec = call.get("function", {})
                    raw_arguments = function_spec.get("arguments", {})
                    logger.info(f"\n工具 #{position}: {function_spec.get('name', 'unknown')}")
                    logger.info(f" Call ID: {call.get('id', 'unknown')}")
                    logger.info(" 参数:")
                    if not isinstance(raw_arguments, str):
                        logger.info(_format_json(raw_arguments, indent=4))
                        continue
                    # String arguments are pretty-printed when they parse as JSON,
                    # otherwise echoed verbatim.
                    try:
                        decoded_arguments = json.loads(raw_arguments)
                    except json.JSONDecodeError:
                        logger.info(f" {raw_arguments}")
                    else:
                        logger.info(_format_json(decoded_arguments, indent=4))
        elif isinstance(body, str):
            open_section("\n[LLM 文本回复]")
            logger.info(body)
        else:
            open_section("\n[内容]")
            logger.info(_format_json(body))
        if message.finish_reason:
            logger.info(f"\n完成原因: {message.finish_reason}")

    elif role == "tool":
        open_section("\n[工具执行结果]")
        if isinstance(body, dict):
            logger.info(f"工具名称: {body.get('tool_name', 'unknown')}")
            logger.info("\n返回结果:")
            outcome = body.get("result", body)
            if isinstance(outcome, str):
                logger.info(outcome)
            elif isinstance(outcome, list):
                for position, entry in enumerate(outcome, start=1):
                    if isinstance(entry, dict) and entry.get("type") == "image_url":
                        # Image parts are not printed; only a placeholder line is logged.
                        logger.info(f" [{position}] 图片 (base64, 已省略显示)")
                    else:
                        logger.info(f" [{position}] {entry}")
            else:
                logger.info(_format_json(outcome))
        else:
            logger.info("(无内容)" if body is None else str(body))

    # ---- usage / cost / timing footers ----------------------------------
    if message.prompt_tokens is not None or message.completion_tokens is not None:
        open_section("\n[Token 使用]")
        for label, count in (
            (" 输入 Tokens", message.prompt_tokens),
            (" 输出 Tokens", message.completion_tokens),
            (" 推理 Tokens", message.reasoning_tokens),
            (" 缓存创建 Tokens", message.cache_creation_tokens),
            (" 缓存读取 Tokens", message.cache_read_tokens),
        ):
            if count is not None:
                logger.info(f"{label}: {count:,}")
        if message.tokens:
            logger.info(f" 总计 Tokens: {message.tokens:,}")
    if message.cost is not None:
        logger.info(f"\n[成本] ${message.cost:.6f}")
    if message.duration_ms is not None:
        logger.info(f"[执行时间] {message.duration_ms}ms")
    logger.info(heavy_rule + "\n")
def _replace_placeholder(prompt: SimplePrompt, placeholder: str, value: str, roles=("system", "user")) -> list:
    """Replace every occurrence of ``placeholder`` with ``value`` in the given prompt roles.

    Only string messages in ``prompt._messages`` are touched.

    Returns:
        The roles (in ``roles`` order) whose text actually contained the
        placeholder, so callers can log successes or warn on a miss.
    """
    replaced_roles = []
    for role in roles:
        text = prompt._messages.get(role)
        if isinstance(text, str) and placeholder in text:
            prompt._messages[role] = text.replace(placeholder, value)
            replaced_roles.append(role)
    return replaced_roles


def _apply_prompt_placeholders(base_dir: Path, prompt: SimplePrompt, persona_dir: "str | None" = None):
    """Inject PRD file contents and persona tree data into the prompt's placeholders.

    Mutates ``prompt._messages`` in place. Substitutions happen in this order:

    1. ``PRD/system.md`` -> ``{system}`` (system message).
    2. ``PRD/create_process_v2.md`` (falling back to ``PRD/create_process.md``)
       -> ``{create_process}`` (system message).
    3. ``PRD/input.md`` -> ``{input}`` and ``PRD/output.md`` -> ``{output}``
       (user message).
    4. Only when ``persona_dir`` is given, in system and user messages:
       ``data/<persona_dir>/tree/<name>.json`` -> ``{{<name>}}`` for the three
       point trees, then ``{{person_name}}`` -> ``persona_dir``.

    Persona substitutions run last on purpose: placeholders that arrive inside
    the injected PRD files are filled too. Previously ``{{person_name}}`` was
    replaced first, so occurrences coming from the PRD files were left as-is.
    Missing files or placeholders are logged as warnings rather than raised.

    Args:
        base_dir: Directory that contains ``PRD/`` and ``data/``.
        prompt: Loaded SimplePrompt; modified in place.
        persona_dir: Persona data directory name, e.g. "家有大志". When falsy,
            neither tree data nor ``{{person_name}}`` is substituted.
    """
    prd_dir = base_dir / "PRD"

    # 1. system.md -> {system}
    system_md_path = prd_dir / "system.md"
    if system_md_path.exists():
        system_content = system_md_path.read_text(encoding="utf-8")
        if _replace_placeholder(prompt, "{system}", system_content, roles=("system",)):
            logger.info(" - 已替换 system.md 内容到 prompt")
        else:
            logger.warning(" - 警告: prompt 中未找到 {system} 占位符")
    else:
        logger.warning(f" - 警告: system.md 文件不存在: {system_md_path}")

    # 2. create_process(_v2).md -> {create_process}; prefer the v2 file when present.
    create_process_md_path = prd_dir / "create_process_v2.md"
    if not create_process_md_path.exists():
        create_process_md_path = prd_dir / "create_process.md"
    if create_process_md_path.exists():
        create_process_content = create_process_md_path.read_text(encoding="utf-8")
        if _replace_placeholder(prompt, "{create_process}", create_process_content, roles=("system",)):
            logger.info(f" - 已替换 {create_process_md_path.name} 内容到 prompt")
        else:
            logger.warning(" - 警告: prompt 中未找到 {create_process} 占位符")
    else:
        logger.warning(f" - 警告: create_process.md 文件不存在: {create_process_md_path}")

    # 3. input.md / output.md -> {input} / {output} in the user message.
    for file_name, placeholder in (("input.md", "{input}"), ("output.md", "{output}")):
        md_path = prd_dir / file_name
        if not md_path.exists():
            logger.warning(f" - 警告: {file_name} 文件不存在: {md_path}")
            continue
        md_content = md_path.read_text(encoding="utf-8")
        if _replace_placeholder(prompt, placeholder, md_content, roles=("user",)):
            logger.info(f" - 已替换 {file_name} 内容到 prompt")
        else:
            logger.warning(f" - 警告: prompt 中未找到 {placeholder} 占位符")

    if not persona_dir:
        return

    # 4a. Persona tree JSON files -> {{<tree name>}} in system and user messages.
    tree_dir = base_dir / "data" / persona_dir / "tree"
    if tree_dir.exists():
        for var_name in ("形式_point_tree_how", "实质_point_tree_how", "意图_point_tree_how"):
            tree_path = tree_dir / f"{var_name}.json"
            if not tree_path.exists():
                logger.warning(f" - 警告: 树文件不存在: {tree_path}")
                continue
            tree_content = tree_path.read_text(encoding="utf-8")
            for role in _replace_placeholder(prompt, "{{" + var_name + "}}", tree_content):
                suffix = " (user)" if role == "user" else ""
                logger.info(f" - 已替换 {var_name} 数据到 prompt{suffix}")
    else:
        logger.warning(f" - 警告: 人设树目录不存在: {tree_dir}")

    # 4b. {{person_name}} last, so occurrences from every injected file are covered.
    for role in _replace_placeholder(prompt, "{{person_name}}", persona_dir):
        suffix = " (user)" if role == "user" else ""
        logger.info(f" - 已替换 {{{{person_name}}}} 为: {persona_dir}{suffix}")
def _load_project_presets(presets_path: Path) -> None:
    """Register every preset in ``presets_path`` (a JSON object of name -> AgentPreset kwargs), if the file exists."""
    if not presets_path.exists():
        return
    with open(presets_path, "r", encoding="utf-8") as f:
        project_presets = json.load(f)
    for name, cfg in project_presets.items():
        register_preset(name, AgentPreset(**cfg))
    logger.info(f" - 已加载项目 presets: {list(project_presets.keys())}")


def _log_rendered_prompt(prompt: SimplePrompt) -> None:
    """Log the system message (and the user message, if any) after placeholder substitution."""
    logger.info("\n替换后的 prompt:")
    logger.info("=" * 60)
    logger.info("System:")
    logger.info("-" * 60)
    logger.info(prompt._messages.get("system", ""))
    logger.info("=" * 60)
    if "user" in prompt._messages:
        logger.info("\nUser:")
        logger.info("-" * 60)
        logger.info(prompt._messages["user"])
        logger.info("=" * 60)
    logger.info("")


def _apply_menu_result(menu_result: dict, run_config) -> "list | None":
    """Translate an interactive-menu result into the inputs of the next run.

    Returns:
        The messages to send on the next run when the action is "continue".
        The list may be empty, meaning resume without new input. Returns None
        for any other action, which callers treat as "stop".

    On "continue", ``run_config.after_sequence`` is always overwritten. It is
    set to the menu's value when new messages are supplied and cleared
    otherwise, so a value chosen in an earlier menu never leaks into later runs.
    """
    if menu_result.get("action") != "continue":
        return None
    new_messages = menu_result.get("messages") or []
    run_config.after_sequence = menu_result.get("after_sequence") if new_messages else None
    return new_messages


def _extract_final_text(message: Message) -> str:
    """Return the assistant's plain-text answer carried by *message*, or "" if none.

    A reply counts as final only when it has text and requests no tool calls.
    Both dict content ({"text", "tool_calls"}) and plain-string content are accepted.
    """
    if message.role != "assistant":
        return ""
    content = message.content
    if isinstance(content, str):
        return content
    if isinstance(content, dict) and not content.get("tool_calls"):
        return content.get("text", "") or ""
    return ""


def _log_trace_event(trace: Trace) -> None:
    """Log a short summary for a Trace status update yielded by the runner."""
    if trace.status == "running":
        logger.info(f"[Trace] 开始: {trace.trace_id[:8]}...")
    elif trace.status == "completed":
        logger.info("\n[Trace] ✅ 完成")
        logger.info(f" - Total messages: {trace.total_messages}")
        logger.info(f" - Total tokens: {trace.total_tokens}")
        logger.info(f" - Total cost: ${trace.total_cost:.4f}")
    elif trace.status == "failed":
        logger.error(f"\n[Trace] ❌ 失败: {trace.error_message}")
    elif trace.status == "stopped":
        logger.info("\n[Trace] ⏸️ 已停止")


async def _report_results(store: FileSystemTraceStore, trace_id, final_response: str, output_dir: Path) -> None:
    """Render the trace to HTML, save the final answer to ``output_dir/result.txt``, and log where to find both."""
    if trace_id:
        html_path = store.base_path / trace_id / "messages.html"
        try:
            await trace_to_html(trace_id, html_path, base_path=str(store.base_path))
            logger.info(f"\n✓ Messages 可视化已保存: {html_path}")
        except Exception as e:
            # Best effort: a rendering failure must not hide the run's result.
            logger.error(f"\n⚠ 生成 HTML 失败: {e}")

    if final_response:
        logger.info("")
        logger.info("=" * 60)
        logger.info("Agent 响应:")
        logger.info("=" * 60)
        logger.info(final_response)
        logger.info("=" * 60)
        logger.info("")
        output_file = output_dir / "result.txt"
        output_file.write_text(final_response, encoding="utf-8")
        logger.info(f"✓ 结果已保存到: {output_file}")
        logger.info("")

    if trace_id:
        html_path = store.base_path / trace_id / "messages.html"
        logger.info("=" * 60)
        logger.info("可视化:")
        logger.info("=" * 60)
        logger.info(f"1. 本地 HTML: {html_path}")
        logger.info("")
        logger.info("2. API Server:")
        logger.info(" python3 api_server.py")
        logger.info(" http://localhost:8000/api/traces")
        logger.info("")
        logger.info(f"3. Trace ID: {trace_id}")
        logger.info("=" * 60)


async def main():
    """Run the create-scenario agent with interactive pause/resume control.

    CLI flags:
        --trace    resume an existing trace instead of starting a new one.
        --persona  persona directory name under data/, forwarded to
                   _apply_prompt_placeholders for tree data and {{person_name}}.

    While running, typing 'p'/'pause' opens the interactive menu and 'q'/'quit'
    stops (as announced in the banner below). After each run the menu is shown
    again until the user stops. The HTML visualization and the final answer are
    always written out in ``finally``.
    """
    parser = argparse.ArgumentParser(description="任务 (Agent 模式 + 交互增强)")
    parser.add_argument(
        "--trace",
        type=str,
        default=None,
        help="已有的 Trace ID,用于恢复继续执行(不指定则新建)",
    )
    parser.add_argument(
        "--persona",
        type=str,
        default=None,
        help="人设数据目录名,如 '家有大志'。用于读取 data/{目录名}/tree 下的树数据",
    )
    args = parser.parse_args()

    base_dir = Path(__file__).parent
    prompt_path = base_dir / "create.prompt"
    output_dir = base_dir / "output_1"
    output_dir.mkdir(exist_ok=True)

    setup_logging(level=LOG_LEVEL, file=LOG_FILE)

    logger.info("2. 加载 presets...")
    _load_project_presets(base_dir / "presets.json")

    logger.info("3. 加载 prompt...")
    prompt = SimplePrompt(prompt_path)
    _apply_prompt_placeholders(base_dir, prompt, persona_dir=args.persona)
    _log_rendered_prompt(prompt)

    logger.info("4. 构建任务消息...")
    messages = prompt.build_messages()

    logger.info("5. 创建 Agent Runner...")
    logger.info(" - 加载自定义工具: topic_search")
    # Imported only for its side effects (loading the custom tool announced above).
    import examples.create.tool  # noqa: F401

    # A model named in the prompt wins; otherwise use RUN_CONFIG.model, adding the
    # "anthropic/" provider prefix when it has none.
    model_from_prompt = prompt.config.get("model")
    model_from_config = RUN_CONFIG.model
    default_model = model_from_config if "/" in model_from_config else f"anthropic/{model_from_config}"
    model = model_from_prompt or default_model

    skills_dir = str(SKILLS_DIR) if Path(SKILLS_DIR).is_absolute() else str((base_dir / SKILLS_DIR).resolve())
    logger.info(f" - Skills 目录: {skills_dir}")
    logger.info(f" - 模型: {model}")

    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_openrouter_llm_call(model=model),
        skills_dir=skills_dir,
        debug=DEBUG,
    )
    interactive = InteractiveController(
        runner=runner,
        store=store,
        enable_stdin_check=True,
    )

    task_name = RUN_CONFIG.name or base_dir.name
    logger.info("=" * 60)
    logger.info(task_name)
    logger.info("=" * 60)
    logger.info("💡 交互提示:")
    logger.info(" - 执行过程中输入 'p' 或 'pause' 暂停并进入交互模式")
    logger.info(" - 执行过程中输入 'q' 或 'quit' 停止执行")
    logger.info("=" * 60)
    logger.info("")

    resume_trace_id = args.trace
    if resume_trace_id:
        existing_trace = await store.get_trace(resume_trace_id)
        if not existing_trace:
            logger.error(f"\n错误: Trace 不存在: {resume_trace_id}")
            sys.exit(1)
        logger.info(f"恢复已有 Trace: {resume_trace_id[:8]}...")
        logger.info(f" - 状态: {existing_trace.status}")
        logger.info(f" - 消息数: {existing_trace.total_messages}")
    else:
        logger.info("启动新 Agent...")
    logger.info("")

    final_response = ""
    current_trace_id = resume_trace_id
    current_sequence = 0
    should_exit = False

    try:
        run_config = copy.deepcopy(RUN_CONFIG)
        run_config.model = model
        run_config.temperature = float(prompt.config.get("temperature", run_config.temperature))
        run_config.max_iterations = int(prompt.config.get("max_iterations", run_config.max_iterations))

        if resume_trace_id:
            # None tells the loop to inspect the stored trace status before running.
            initial_messages = None
            run_config.trace_id = resume_trace_id
        else:
            initial_messages = messages
            run_config.name = "社交媒体内容解构、建构、评估任务"

        while not should_exit:
            if current_trace_id:
                run_config.trace_id = current_trace_id
            final_response = ""

            # When resuming, a trace that already finished goes straight to the menu.
            if current_trace_id and initial_messages is None:
                check_trace = await store.get_trace(current_trace_id)
                if check_trace and check_trace.status in ("completed", "failed"):
                    if check_trace.status == "completed":
                        logger.info("\n[Trace] ✅ 已完成")
                        logger.info(f" - Total messages: {check_trace.total_messages}")
                        logger.info(f" - Total cost: ${check_trace.total_cost:.4f}")
                    else:
                        logger.error(f"\n[Trace] ❌ 已失败: {check_trace.error_message}")
                    current_sequence = check_trace.head_sequence
                    menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                    initial_messages = _apply_menu_result(menu_result, run_config)
                    if initial_messages is None:
                        break
                    continue
                initial_messages = []

            logger.info("▶️ 开始执行..." if not current_trace_id else "▶️ 继续执行...")
            paused = False
            try:
                stream = runner.run(messages=initial_messages, config=run_config)
                try:
                    async for item in stream:
                        cmd = interactive.check_stdin()
                        if cmd == "pause":
                            logger.info("\n⏸️ 正在暂停执行...")
                            if current_trace_id:
                                await runner.stop(current_trace_id)
                            await asyncio.sleep(0.5)
                            menu_result = await interactive.show_menu(current_trace_id, current_sequence)
                            next_messages = _apply_menu_result(menu_result, run_config)
                            if next_messages is None:
                                should_exit = True
                            else:
                                initial_messages = next_messages
                            paused = True
                            break
                        if cmd == "quit":
                            logger.info("\n🛑 用户请求停止...")
                            if current_trace_id:
                                await runner.stop(current_trace_id)
                            should_exit = True
                            break

                        if isinstance(item, Trace):
                            current_trace_id = item.trace_id
                            _log_trace_event(item)
                        elif isinstance(item, Message):
                            current_sequence = item.sequence
                            _print_message_details(item)
                            final_text = _extract_final_text(item)
                            if final_text:
                                final_response = final_text
                finally:
                    # Breaking out of `async for` does not close an async generator; close it
                    # now so its cleanup finishes before the next run starts on the same trace.
                    aclose = getattr(stream, "aclose", None)
                    if aclose is not None:
                        await aclose()
            except Exception as e:
                logger.exception(f"\n执行出错: {e}")

            if should_exit:
                break
            if paused:
                # The pause menu already prepared initial_messages / after_sequence.
                continue
            if not current_trace_id:
                break

            menu_result = await interactive.show_menu(current_trace_id, current_sequence)
            initial_messages = _apply_menu_result(menu_result, run_config)
            if initial_messages is None:
                break
    except (KeyboardInterrupt, asyncio.CancelledError):
        # Since Python 3.11, asyncio.run() turns Ctrl+C into cancellation of this
        # task (CancelledError) instead of raising KeyboardInterrupt in here.
        logger.info("\n\n用户中断 (Ctrl+C)")
        if current_trace_id:
            await runner.stop(current_trace_id)
    finally:
        await _report_results(store, current_trace_id, final_response, output_dir)


if __name__ == "__main__":
    asyncio.run(main())
|