# run_single.py
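"""
Entry point for running the content-finding agent.

Logging design:
- agent.log: the complete execution trace in JSONL format (never truncated),
  used for HTML visualisation
- console: brief, human-readable output
"""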
import asyncio
import json
import logging
import os
import sys
import traceback as tb_mod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
  15. PROJECT_ROOT = Path(__file__).resolve().parent
  16. log_dir = PROJECT_ROOT / '.cache'
  17. log_dir.mkdir(exist_ok=True)
  18. TRACE_LOG_PATH = log_dir / 'agent.log'
# ============================================================
# TraceWriter: structured JSONL trace log (written to agent.log, never truncated)
# ============================================================
  22. class TraceWriter:
  23. """将 Agent 执行的每一步写为 JSONL 事件到 agent.log,不做任何截断。"""
  24. def __init__(self, path: Path):
  25. self._file = open(path, 'w', encoding='utf-8')
  26. self._iteration = 0
  27. def _ts(self) -> str:
  28. return datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
  29. def _write(self, event: dict):
  30. event["ts"] = self._ts()
  31. self._file.write(json.dumps(event, ensure_ascii=False) + '\n')
  32. self._file.flush()
  33. # --- 任务生命周期 ---
  34. def log_init(self, query: str, demand_id: int, model: str, trace_id: str = None):
  35. self._write({
  36. "type": "init",
  37. "query": query,
  38. "demand_id": demand_id,
  39. "model": model,
  40. "trace_id": trace_id,
  41. })
  42. def log_complete(self, trace_id: str, status: str, error: str = None):
  43. self._write({
  44. "type": "complete",
  45. "trace_id": trace_id,
  46. "status": status,
  47. "error": error,
  48. "total_iterations": self._iteration,
  49. })
  50. # --- 框架日志 ---
  51. def log_framework(self, logger_name: str, level: str, message: str):
  52. self._write({
  53. "type": "framework",
  54. "logger": logger_name,
  55. "level": level,
  56. "msg": message,
  57. })
  58. # --- Agent 思考 / LLM 输出 ---
  59. def log_assistant(self, text: str, tool_calls: List[dict] = None,
  60. reasoning: str = None, tokens: dict = None):
  61. """记录 LLM 的完整输出(思考文本 + 工具调用,不截断)。"""
  62. self._iteration += 1
  63. parsed_calls = []
  64. for tc in (tool_calls or []):
  65. func = tc.get("function", {})
  66. name = func.get("name", "unknown")
  67. args_str = func.get("arguments", "{}")
  68. try:
  69. params = json.loads(args_str)
  70. except (json.JSONDecodeError, TypeError):
  71. params = args_str
  72. parsed_calls.append({
  73. "name": name,
  74. "params": params,
  75. "call_id": tc.get("id", ""),
  76. })
  77. self._write({
  78. "type": "assistant",
  79. "iteration": self._iteration,
  80. "text": text or "",
  81. "tool_calls": parsed_calls,
  82. "reasoning": reasoning or "",
  83. "tokens": tokens or {},
  84. })
  85. # --- 工具结果 ---
  86. def log_tool_result(self, tool_name: str, result: Any, call_id: str = ""):
  87. """记录工具的完整返回(不截断)。"""
  88. if isinstance(result, list):
  89. texts = [
  90. p.get("text", "") for p in result
  91. if isinstance(p, dict) and p.get("type") == "text"
  92. ]
  93. output = "\n".join(texts) if texts else json.dumps(result, ensure_ascii=False)
  94. elif isinstance(result, str):
  95. output = result
  96. else:
  97. output = str(result)
  98. self._write({
  99. "type": "tool_result",
  100. "name": tool_name,
  101. "call_id": call_id,
  102. "output": output,
  103. })
  104. # --- 错误 ---
  105. def log_error(self, message: str, traceback_str: str = ""):
  106. self._write({
  107. "type": "error",
  108. "msg": message,
  109. "traceback": traceback_str,
  110. })
  111. def close(self):
  112. if self._file and not self._file.closed:
  113. self._file.close()
  114. class JsonlLogHandler(logging.Handler):
  115. """将 Python logging 记录路由到 TraceWriter(JSONL 格式)。"""
  116. def __init__(self, trace_writer: TraceWriter):
  117. super().__init__()
  118. self.trace_writer = trace_writer
  119. def emit(self, record: logging.LogRecord):
  120. try:
  121. self.trace_writer.log_framework(
  122. record.name,
  123. record.levelname,
  124. record.getMessage(),
  125. )
  126. except Exception:
  127. pass
# ============================================================
# Console logging (human-readable)
# ============================================================
  131. console_handler = logging.StreamHandler(sys.stdout)
  132. console_handler.setFormatter(
  133. logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  134. )
  135. logging.basicConfig(level=logging.INFO, handlers=[console_handler], force=True)
  136. logger = logging.getLogger(__name__)
# ============================================================
# Third-party / project imports (placed after logging is configured)
# ============================================================
  140. from dotenv import load_dotenv
  141. load_dotenv()
  142. from tools import fetch_account_article_list, fetch_weixin_account, weixin_search
  143. from agent import AgentRunner, RunConfig, FileSystemTraceStore, Trace, Message
  144. from agent.llm import create_openrouter_llm_call
  145. from agent.llm.prompts import SimplePrompt
  146. from agent.tools.builtin.knowledge import KnowledgeConfig
  147. DEFAULT_QUERY = "伊朗、以色列、和平是永恒的主题"
  148. DEFAULT_DEMAND_ID = 1
# ============================================================
# Helper functions
# ============================================================
  152. def _normalize_ascii_double_quotes(text: str) -> str:
  153. """将字符串中的 ASCII 双引号 `"` 规范化为中文双引号 `"`、`"`。"""
  154. if '"' not in text:
  155. return text
  156. chars: list[str] = []
  157. open_quote = True
  158. for ch in text:
  159. if ch == '"':
  160. chars.append("\u201c" if open_quote else "\u201d")
  161. open_quote = not open_quote
  162. else:
  163. chars.append(ch)
  164. return "".join(chars)
  165. def _sanitize_json_strings(value: Any) -> Any:
  166. if isinstance(value, str):
  167. return _normalize_ascii_double_quotes(value)
  168. if isinstance(value, list):
  169. return [_sanitize_json_strings(v) for v in value]
  170. if isinstance(value, dict):
  171. return {k: _sanitize_json_strings(v) for k, v in value.items()}
  172. return value
  173. def _sanitize_output_json(output_json_path: Path) -> None:
  174. if not output_json_path.exists():
  175. logger.warning(f"未找到 output.json,跳过清洗: {output_json_path}")
  176. return
  177. try:
  178. data = json.loads(output_json_path.read_text(encoding="utf-8"))
  179. except Exception as e:
  180. logger.warning(f"output.json 解析失败,跳过清洗: {e}")
  181. return
  182. cleaned = _sanitize_json_strings(data)
  183. output_json_path.write_text(
  184. json.dumps(cleaned, ensure_ascii=False, indent=2),
  185. encoding="utf-8"
  186. )
  187. logger.info(f"已完成 output.json 引号清洗: {output_json_path}")
# ============================================================
# Console streaming output (concise)
# ============================================================
  191. def _print_assistant(text: str, tool_calls: list):
  192. """向控制台打印 Agent 输出的摘要。"""
  193. if text:
  194. logger.info("\n%s", text)
  195. for tc in (tool_calls or []):
  196. name = tc.get("function", {}).get("name", "unknown")
  197. if name not in ("goal", "get_current_context"):
  198. logger.info("[工具] %s", name)
  199. def _print_tool_result(tool_name: str):
  200. """向控制台打印工具结果标记。"""
  201. if tool_name not in ("goal", "get_current_context"):
  202. logger.info("[结果] %s ✓", tool_name)
# ============================================================
# Agent execution
# ============================================================
  206. async def run_agent(
  207. query: Optional[str] = None,
  208. demand_id: Optional[int] = None,
  209. stream_output: bool = True,
  210. ) -> Dict[str, Any]:
  211. query = query or DEFAULT_QUERY
  212. demand_id = demand_id or DEFAULT_DEMAND_ID
  213. # 创建 TraceWriter → agent.log(JSONL,完整无截断)
  214. tw = TraceWriter(TRACE_LOG_PATH)
  215. # 将 Python logging 也路由到 JSONL
  216. jsonl_handler = JsonlLogHandler(tw)
  217. jsonl_handler.setLevel(logging.DEBUG)
  218. logging.getLogger().addHandler(jsonl_handler)
  219. prompt_path = PROJECT_ROOT / "content_finder.md"
  220. prompt = SimplePrompt(prompt_path)
  221. output_dir = str(PROJECT_ROOT / "output")
  222. demand_id_str = str(demand_id) if demand_id is not None else ""
  223. messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
  224. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  225. if not api_key:
  226. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  227. model_name = prompt.config.get("model", "sonnet-4.6")
  228. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  229. temperature = float(prompt.config.get("temperature", 0.3))
  230. max_iterations = 30
  231. trace_dir = str(PROJECT_ROOT / "traces")
  232. skills_dir = str(PROJECT_ROOT / "skills")
  233. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  234. store = FileSystemTraceStore(base_path=trace_dir)
  235. allowed_tools = [
  236. "weixin_search",
  237. "fetch_weixin_account",
  238. "fetch_account_article_list",
  239. "fetch_article_detail",
  240. ]
  241. runner = AgentRunner(
  242. llm_call=create_openrouter_llm_call(model=model),
  243. trace_store=store,
  244. skills_dir=skills_dir,
  245. )
  246. config = RunConfig(
  247. name="内容寻找",
  248. model=model,
  249. temperature=temperature,
  250. max_iterations=max_iterations,
  251. tools=allowed_tools,
  252. extra_llm_params={"max_tokens": 8192},
  253. knowledge=KnowledgeConfig(
  254. enable_extraction=False,
  255. enable_completion_extraction=False,
  256. enable_injection=False,
  257. )
  258. )
  259. tw.log_init(query, demand_id, model)
  260. trace_id = None
  261. try:
  262. async for item in runner.run(messages=messages, config=config):
  263. # ---------- Trace 对象 ----------
  264. if isinstance(item, Trace):
  265. trace_id = item.trace_id
  266. tw._write({"type": "trace_status", "trace_id": trace_id, "status": item.status})
  267. if item.status == "completed":
  268. if trace_id:
  269. output_json_path = Path(output_dir) / trace_id / "output.json"
  270. _sanitize_output_json(output_json_path)
  271. tw.log_complete(trace_id, "completed")
  272. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  273. return {"trace_id": trace_id, "status": "completed"}
  274. elif item.status == "failed":
  275. tw.log_complete(trace_id, "failed", item.error_message)
  276. logger.error(f"Agent 执行失败: {item.error_message}")
  277. return {"trace_id": trace_id, "status": "failed", "error": item.error_message}
  278. # ---------- Message 对象 ----------
  279. elif isinstance(item, Message):
  280. # --- Assistant 消息(思考 + 工具调用)---
  281. if item.role == "assistant":
  282. content = item.content
  283. if isinstance(content, dict):
  284. text = content.get("text", "")
  285. tool_calls = content.get("tool_calls", [])
  286. reasoning = content.get("reasoning_content", "")
  287. # JSONL: 完整记录,不截断
  288. tw.log_assistant(
  289. text=text,
  290. tool_calls=tool_calls,
  291. reasoning=reasoning,
  292. tokens={
  293. "prompt": getattr(item, "prompt_tokens", None),
  294. "completion": getattr(item, "completion_tokens", None),
  295. },
  296. )
  297. # 控制台:简要输出
  298. if stream_output:
  299. _print_assistant(text, tool_calls)
  300. elif isinstance(content, str) and content:
  301. tw.log_assistant(text=content)
  302. if stream_output:
  303. logger.info("\n%s", content)
  304. # --- Tool 消息(工具返回)---
  305. elif item.role == "tool":
  306. content = item.content
  307. if isinstance(content, dict):
  308. tool_name = content.get("tool_name", "unknown")
  309. result = content.get("result", "")
  310. error = content.get("error")
  311. # JSONL: 完整记录,不截断
  312. if error:
  313. tw.log_error(
  314. message=f"Tool {tool_name}: {error}",
  315. )
  316. else:
  317. tw.log_tool_result(
  318. tool_name=tool_name,
  319. result=result,
  320. call_id=item.tool_call_id or "",
  321. )
  322. # 控制台:简要标记
  323. if stream_output:
  324. _print_tool_result(tool_name)
  325. # 循环正常结束但未返回
  326. tw.log_complete(trace_id, "failed", "Agent 异常退出(循环结束未返回结果)")
  327. return {"trace_id": trace_id, "status": "failed", "error": "Agent 异常退出"}
  328. except KeyboardInterrupt:
  329. logger.info("用户中断")
  330. tw.log_complete(trace_id, "interrupted", "用户中断")
  331. if stream_output:
  332. logger.info("用户中断")
  333. return {"trace_id": trace_id, "status": "failed", "error": "用户中断"}
  334. except Exception as e:
  335. tb_str = tb_mod.format_exc()
  336. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  337. tw.log_error(str(e), tb_str)
  338. tw.log_complete(trace_id, "failed", str(e))
  339. if stream_output:
  340. logger.error("执行失败: %s", e)
  341. return {"trace_id": trace_id, "status": "failed", "error": str(e)}
  342. finally:
  343. logging.getLogger().removeHandler(jsonl_handler)
  344. tw.close()
  345. async def main():
  346. try:
  347. result = await run_agent(query=None, demand_id=None, stream_output=True)
  348. if result["status"] == "completed":
  349. logger.info(f"[完成] trace_id={result['trace_id']}")
  350. else:
  351. logger.error(f"[失败] trace_id={result.get('trace_id')}, 错误: {result.get('error')}")
  352. sys.exit(1)
  353. except KeyboardInterrupt:
  354. logger.info("用户中断")
  355. except Exception as e:
  356. logger.error(f"执行失败: {e}", exc_info=True)
  357. sys.exit(1)
  358. if __name__ == "__main__":
  359. import asyncio
  360. asyncio.run(main())