run_search_agent.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. """
Search Agent Harness — 约束驱动的搜索 Agent 入口。
Harness Engineering 分层:
1. Budget Harness — 显式限定运行预算(整体超时、目标文章数上限、补召回轮次上限)
2. Planner Harness — 启动前打印运行计划,明确每阶段目标与约束
3. Observer Harness — 运行结束后输出结构化摘要(RunSummary),暴露关键指标与阶段历史
4. Fallback Harness — DB 策略读取失败时显式降级为默认策略;缺少 API Key 时启动前快速失败
前置:
- OPEN_ROUTER_API_KEY
- 可选:SEARCH_AGENT_DB_* 与表 search_agent_strategy(见 docs/search_agent_strategy.sql)
环境变量:
- PIPELINE_QUERY / 默认 "伊朗以色列冲突、中老年人会关注什么?"
- PIPELINE_DEMAND_ID / 默认 "1"
- PIPELINE_TIMEOUT / 整个 Agent 超时秒数,默认 1800(30 分钟),至少 30
- PIPELINE_TARGET_COUNT / 目标文章数,默认取 RuntimePipelineConfig;实际值不超过 PIPELINE_MAX_TARGET_COUNT
- PIPELINE_MAX_TARGET_COUNT / 单次运行目标文章数上限,默认 10,取值 [1, 200]
- PIPELINE_MAX_FALLBACK_ROUNDS / 补召回最大轮次,默认 1,取值 [0, 5]
- LOG_LEVEL / 根日志级别,默认 DEBUG
- CONSOLE_LOG_LEVEL / 控制台日志级别,默认 INFO
  16. """
  17. from __future__ import annotations
  18. import asyncio
  19. import logging
  20. import os
  21. import shutil
  22. import sys
  23. import tempfile
  24. import time
  25. from dataclasses import dataclass, field
  26. from typing import Optional
  27. from uuid import uuid4
  28. from dotenv import load_dotenv
  29. from src.domain.search.core import SearchAgentCore
  30. from src.domain.search.policy import SearchAgentPolicy
  31. load_dotenv()
  32. # ── 日志级别由环境变量控制 ────────────
  33. _LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper()
  34. _CONSOLE_LEVEL = os.getenv("CONSOLE_LOG_LEVEL", "INFO").upper()
  35. _LOG_FMT = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
  36. _LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
  37. # 全局文件 handler 引用,供 main() 移动日志文件
  38. _file_handler: Optional[logging.FileHandler] = None
  39. _tmp_log_path: Optional[str] = None
  40. def _setup_logging() -> None:
  41. """
  42. 配置双通道日志:console(INFO)+ file(DEBUG)。
  43. 全量日志写入临时文件,pipeline 完成后移入 trace 目录。
  44. """
  45. global _file_handler, _tmp_log_path
  46. root = logging.getLogger()
  47. root.setLevel(getattr(logging, _LOG_LEVEL, logging.DEBUG))
  48. formatter = logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATEFMT)
  49. console = logging.StreamHandler(sys.__stdout__)
  50. console.setLevel(getattr(logging, _CONSOLE_LEVEL, logging.INFO))
  51. console.setFormatter(formatter)
  52. root.addHandler(console)
  53. tmp = tempfile.NamedTemporaryFile(
  54. delete=False, suffix=".log", prefix="search_agent_", mode="w", encoding="utf-8",
  55. )
  56. _tmp_log_path = tmp.name
  57. tmp.close()
  58. _file_handler = logging.FileHandler(_tmp_log_path, mode="w", encoding="utf-8")
  59. _file_handler.setLevel(logging.DEBUG)
  60. _file_handler.setFormatter(formatter)
  61. root.addHandler(_file_handler)
  62. for noisy in ("httpx", "httpcore", "urllib3", "asyncio"):
  63. logging.getLogger(noisy).setLevel(logging.WARNING)
  64. # agent 内核日志不写入全量日志文件(减少噪音)
  65. # 过滤 agent.core.runner / agent.llm.* / agent.tools.* / agent.trace.* 等
  66. class _AgentLogFilter(logging.Filter):
  67. def filter(self, record: logging.LogRecord) -> bool:
  68. return not record.name.startswith("agent.")
  69. _file_handler.addFilter(_AgentLogFilter())
  70. _setup_logging()
  71. logger = logging.getLogger(__name__)
  72. # ─────────────────────────────────────────────
  73. # 1. Budget Harness — 运行预算约束
  74. # ─────────────────────────────────────────────
  75. @dataclass
  76. class AgentBudget:
  77. """
  78. 显式声明 Agent 可消耗的资源上限。
  79. 约束驱动原则:
  80. - 所有上限必须在启动前确定,不允许在运行中隐式扩张。
  81. - 超时由 harness 层统一兜底,不依赖各 Stage 自己的超时。
  82. """
  83. timeout_seconds: int = 1800 # 整体超时(30 分钟)
  84. max_target_count: int = 10 # 单次运行最多产出文章数(防止无限扩张)
  85. max_fallback_rounds: int = 1 # content_search gate fallback 最大轮次(防止死循环)
  86. @classmethod
  87. def from_env(cls) -> "AgentBudget":
  88. return cls(
  89. timeout_seconds=int(os.getenv("PIPELINE_TIMEOUT", "1800")),
  90. max_target_count=int(os.getenv("PIPELINE_MAX_TARGET_COUNT", "10")),
  91. max_fallback_rounds=int(os.getenv("PIPELINE_MAX_FALLBACK_ROUNDS", "1")),
  92. )
  93. def validate(self) -> None:
  94. """前置断言:预算参数必须在合理范围内。"""
  95. if self.timeout_seconds < 30:
  96. raise ValueError(f"timeout_seconds 至少 30 秒,当前: {self.timeout_seconds}")
  97. if self.max_target_count < 1 or self.max_target_count > 200:
  98. raise ValueError(f"max_target_count 须在 [1, 200],当前: {self.max_target_count}")
  99. if self.max_fallback_rounds < 0 or self.max_fallback_rounds > 5:
  100. raise ValueError(f"max_fallback_rounds 须在 [0, 5],当前: {self.max_fallback_rounds}")
  101. # ─────────────────────────────────────────────
  102. # 2. Observer Harness — 结构化运行摘要
  103. # ─────────────────────────────────────────────
  104. @dataclass
  105. class RunSummary:
  106. """
  107. Agent 运行后的结构化摘要(非裸日志)。
  108. 设计意图:
  109. - 调用方可检查 success / error_message 决定后续动作。
  110. - 关键指标(candidate_count / filtered_count)可接入告警。
  111. """
  112. success: bool
  113. query: str
  114. demand_id: str
  115. policy_source: str = "unknown" # "db" | "default" | "override"
  116. trace_id: Optional[str] = None
  117. output_file: str = ""
  118. candidate_count: int = 0
  119. filtered_count: int = 0
  120. account_count: int = 0
  121. elapsed_seconds: float = 0.0
  122. error_message: str = ""
  123. stage_history: list = field(default_factory=list)
  124. def log(self) -> None:
  125. """结构化打印运行摘要。"""
  126. status = "✅ 成功" if self.success else "❌ 失败"
  127. logger.info("=" * 60)
  128. logger.info("Agent 运行摘要 %s", status)
  129. logger.info(" query : %s", self.query)
  130. logger.info(" demand_id : %s", self.demand_id)
  131. logger.info(" policy_source: %s", self.policy_source)
  132. logger.info(" trace_id : %s", self.trace_id)
  133. logger.info(" output_file : %s", self.output_file)
  134. logger.info(" 候选文章数 : %d", self.candidate_count)
  135. logger.info(" 入选文章数 : %d", self.filtered_count)
  136. logger.info(" 账号数 : %d", self.account_count)
  137. logger.info(" 耗时 : %.1f 秒", self.elapsed_seconds)
  138. if self.error_message:
  139. logger.error(" 错误信息 : %s", self.error_message)
  140. if self.stage_history:
  141. logger.info(" 阶段历史:")
  142. for record in self.stage_history:
  143. status_flag = "✓" if record.get("status") == "completed" else "✗"
  144. logger.info(
  145. " %s %-28s attempt=%d",
  146. status_flag,
  147. record.get("stage_name", "?"),
  148. record.get("attempt", 1),
  149. )
  150. logger.info("=" * 60)
  151. # ─────────────────────────────────────────────
  152. # 3. Planner Harness — 启动前打印运行计划
  153. # ─────────────────────────────────────────────
  154. def print_run_plan(query: str, demand_id: str, budget: AgentBudget, trace_id: str) -> dict:
  155. """
  156. 在 Agent 启动前打印结构化运行计划,并返回计划数据供 trace 使用。
  157. 目的:
  158. - 使运行意图可见、可审计,便于调试和追溯。
  159. - 明确各阶段目标与约束,防止"黑盒"执行。
  160. """
  161. logger.info("=" * 60)
  162. logger.info("▶ Search Agent 运行计划")
  163. logger.info(" Trace ID : %s", trace_id)
  164. logger.info(" Query : %s", query)
  165. logger.info(" Demand ID : %s", demand_id or "(未指定,使用 default 策略)")
  166. logger.info(" 超时上限 : %d 秒", budget.timeout_seconds)
  167. logger.info(" 目标文章上限 : %d 篇", budget.max_target_count)
  168. logger.info(" 最大补召回轮次: %d 轮", budget.max_fallback_rounds)
  169. logger.info("")
  170. logger.info(" 阶段规划:")
  171. logger.info(" 1. [demand_analysis ] ← 需求理解,产出搜索策略(无工具调用)")
  172. logger.info(" 2. [content_search ] ← 按关键词召回候选文章")
  173. logger.info(" └─ Gate: SearchCompletenessGate — 候选不足则 abort")
  174. logger.info(" 3. [hard_filter ] ← 去重 + URL / 时间基础校验")
  175. logger.info(" 4. [coarse_filter ] ← LLM 标题语义粗筛")
  176. logger.info(" 5. [quality_filter ] ← 数据指标评分 + LLM 正文精排")
  177. logger.info(" └─ Gate: FilterSufficiencyGate — 不足则回退补召回(最多 %d 轮)",
  178. budget.max_fallback_rounds)
  179. logger.info(" 6. [account_precipitate] ← 账号信息沉淀")
  180. logger.info(" 7. [output_persist ] ← 输出结构化 JSON")
  181. logger.info(" └─ Gate: OutputSchemaGate — 结构校验")
  182. logger.info("=" * 60)
  183. return {
  184. "trace_id": trace_id,
  185. "query": query,
  186. "demand_id": demand_id or "",
  187. "timeout_seconds": budget.timeout_seconds,
  188. "max_target_count": budget.max_target_count,
  189. "max_fallback_rounds": budget.max_fallback_rounds,
  190. "stages": [
  191. {"name": "demand_analysis", "label": "需求理解,产出搜索策略"},
  192. {"name": "content_search", "label": "按关键词召回候选文章", "gate": "SearchCompletenessGate"},
  193. {"name": "hard_filter", "label": "去重 + 基础规则过滤"},
  194. {"name": "coarse_filter", "label": "LLM 标题语义粗筛"},
  195. {"name": "quality_filter", "label": "数据指标评分 + LLM 正文精排", "gate": "FilterSufficiencyGate"},
  196. {"name": "account_precipitate", "label": "账号信息沉淀"},
  197. {"name": "output_persist", "label": "输出结构化 JSON", "gate": "OutputSchemaGate"},
  198. ],
  199. }
  200. # ─────────────────────────────────────────────
  201. # 4. Fallback Harness — 前置检查与降级路径
  202. # ─────────────────────────────────────────────
  203. def validate_prerequisites() -> None:
  204. """
  205. 前置条件检查(Harness 级别,不依赖 Core 内部检查)。
  206. 设计意图:
  207. - 把必须满足的约束提升到最外层,让失败快速、信息明确。
  208. - 避免在深层 Stage 里才触发 "OPEN_ROUTER_API_KEY 未设置"。
  209. """
  210. api_key = os.getenv("OPEN_ROUTER_API_KEY", "").strip()
  211. if not api_key:
  212. raise EnvironmentError(
  213. "缺少必要环境变量: OPEN_ROUTER_API_KEY\n"
  214. "请在 .env 文件或系统环境中设置该变量后重试。"
  215. )
  216. # ─────────────────────────────────────────────
  217. # 5. 主流程 — Harness 统一编排
  218. # ─────────────────────────────────────────────
  219. async def run_with_harness(
  220. query: str,
  221. demand_id: str,
  222. budget: AgentBudget,
  223. trace_id: str,
  224. use_db_policy: bool = True,
  225. run_plan: dict | None = None,
  226. ) -> RunSummary:
  227. """
  228. 带 Harness 的 Agent 执行入口。
  229. 职责分层:
  230. - 本函数只做"约束注入 + 超时包裹 + 摘要采集"。
  231. - 业务逻辑委托给 SearchAgentCore。
  232. - 不在这里写 if/else 业务判断。
  233. """
  234. start = time.monotonic()
  235. summary = RunSummary(success=False, query=query, demand_id=demand_id, trace_id=trace_id)
  236. # --- 策略来源标记(Observer 用) ---
  237. core = SearchAgentCore()
  238. policy_override: Optional[SearchAgentPolicy] = None
  239. if use_db_policy:
  240. try:
  241. # 预读策略仅用于确认 DB 连通性和标记来源;
  242. # SearchAgentCore.run() 内部会用同一 demand_id 再次加载。
  243. await core.load_policy(demand_id or None)
  244. summary.policy_source = "db"
  245. logger.info("策略已从 DB 加载: demand_id=%s", demand_id)
  246. except Exception as exc:
  247. logger.warning("DB 策略读取失败,降级为默认策略: %s", exc)
  248. policy_override = SearchAgentPolicy.defaults()
  249. summary.policy_source = "default(fallback)"
  250. else:
  251. policy_override = SearchAgentPolicy.defaults()
  252. summary.policy_source = "default"
  253. # --- 预算注入:target_count 不超过 max_target_count ---
  254. from src.pipeline.config.pipeline_config import RuntimePipelineConfig
  255. runtime = RuntimePipelineConfig.from_env()
  256. effective_target = min(runtime.target_count, budget.max_target_count)
  257. if effective_target != runtime.target_count:
  258. logger.info(
  259. "target_count 被 Budget Harness 限制: %d → %d",
  260. runtime.target_count,
  261. effective_target,
  262. )
  263. # --- 超时包裹执行 ---
  264. try:
  265. ctx = await asyncio.wait_for(
  266. core.run(
  267. query=query,
  268. demand_id=demand_id,
  269. target_count=effective_target,
  270. use_db_policy=(policy_override is None),
  271. policy_override=policy_override,
  272. trace_id=trace_id,
  273. run_plan=run_plan,
  274. ),
  275. timeout=budget.timeout_seconds,
  276. )
  277. except asyncio.TimeoutError:
  278. summary.elapsed_seconds = time.monotonic() - start
  279. summary.error_message = f"Agent 超时(>{budget.timeout_seconds}s),已中止"
  280. logger.error(summary.error_message)
  281. return summary
  282. except Exception as exc:
  283. summary.elapsed_seconds = time.monotonic() - start
  284. summary.error_message = str(exc)
  285. logger.exception("Agent 运行异常: %s", exc)
  286. return summary
  287. # --- 采集 Observer 摘要 ---
  288. summary.success = True
  289. summary.output_file = ctx.metadata.get("output_file", "")
  290. summary.candidate_count = len(ctx.candidate_articles)
  291. summary.filtered_count = len(ctx.filtered_articles)
  292. summary.account_count = len(ctx.accounts)
  293. summary.elapsed_seconds = time.monotonic() - start
  294. summary.stage_history = [
  295. {
  296. "stage_name": r.stage_name,
  297. "status": r.status,
  298. "attempt": r.attempt,
  299. }
  300. for r in ctx.stage_history
  301. ]
  302. return summary
  303. async def main() -> None:
  304. # ① 前置检查(Fallback Harness)
  305. validate_prerequisites()
  306. # ② 读取运行参数
  307. query = os.getenv("PIPELINE_QUERY", "伊朗以色列冲突、中老年人会关注什么?")
  308. demand_id = os.getenv("PIPELINE_DEMAND_ID", "1")
  309. # ③ 预算约束(Budget Harness)
  310. budget = AgentBudget.from_env()
  311. budget.validate()
  312. # ④ 生成全局 trace_id,贯穿整个运行周期
  313. trace_id = str(uuid4())
  314. logger.info("Trace ID: %s", trace_id)
  315. # ⑤ 运行计划(Planner Harness)
  316. run_plan = print_run_plan(query=query, demand_id=demand_id, budget=budget, trace_id=trace_id)
  317. # ⑥ 执行(带约束 + 观测)
  318. summary = await run_with_harness(
  319. query=query,
  320. demand_id=demand_id,
  321. budget=budget,
  322. trace_id=trace_id,
  323. use_db_policy=True,
  324. run_plan=run_plan,
  325. )
  326. # ⑦ 结构化输出摘要(Observer Harness)
  327. summary.log()
  328. # ⑧ 将全量日志移入 trace 目录
  329. global _file_handler, _tmp_log_path
  330. if _file_handler and _tmp_log_path and os.path.exists(_tmp_log_path):
  331. try:
  332. _file_handler.close()
  333. trace_dir = os.path.join("tests", "traces", trace_id)
  334. os.makedirs(trace_dir, exist_ok=True)
  335. dest = os.path.join(trace_dir, "full_log.log")
  336. shutil.move(_tmp_log_path, dest)
  337. logger.info("完整日志已保存: %s", dest)
  338. except Exception as exc:
  339. logger.warning("移动日志文件失败: %s", exc)
  340. # ⑨ 非零退出码(让 CI/调度系统能感知失败)
  341. if not summary.success:
  342. raise SystemExit(1)
  343. if __name__ == "__main__":
  344. asyncio.run(main())