run_pipeline.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. import os
  5. import shutil
  6. import sys
  7. import tempfile
  8. from dotenv import load_dotenv
  9. from src.pipeline.runner import run_content_finder_from_cli
load_dotenv()

# Log levels are driven by environment variables.  The root/file level
# defaults to DEBUG so the log file captures everything; the console
# defaults to the quieter INFO.
_LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper()
_CONSOLE_LEVEL = os.getenv("CONSOLE_LOG_LEVEL", "INFO").upper()
# Shared record format / timestamp format for both console and file handlers.
_LOG_FMT = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
  16. def _setup_logging(log_file_path: str) -> logging.FileHandler:
  17. """
  18. 配置双通道日志:console(INFO)+ file(DEBUG)。
  19. 不修改 agent 内核代码,通过 root logger 拦截所有子 logger 输出。
  20. """
  21. root = logging.getLogger()
  22. root.setLevel(getattr(logging, _LOG_LEVEL, logging.DEBUG))
  23. formatter = logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATEFMT)
  24. console = logging.StreamHandler(sys.__stdout__)
  25. console.setLevel(getattr(logging, _CONSOLE_LEVEL, logging.INFO))
  26. console.setFormatter(formatter)
  27. root.addHandler(console)
  28. fh = logging.FileHandler(log_file_path, mode="w", encoding="utf-8")
  29. fh.setLevel(logging.DEBUG)
  30. fh.setFormatter(formatter)
  31. root.addHandler(fh)
  32. for noisy in ("httpx", "httpcore", "urllib3", "asyncio"):
  33. logging.getLogger(noisy).setLevel(logging.WARNING)
  34. # agent 内核日志不写入全量日志文件(减少噪音)
  35. class _AgentLogFilter(logging.Filter):
  36. def filter(self, record: logging.LogRecord) -> bool:
  37. return not record.name.startswith("agent.")
  38. fh.addFilter(_AgentLogFilter())
  39. return fh
# Module-level logger; resolves to "__main__" when executed as a script.
logger = logging.getLogger(__name__)
  41. async def main() -> None:
  42. tmp = tempfile.NamedTemporaryFile(
  43. delete=False, suffix=".log", prefix="pipeline_run_", mode="w", encoding="utf-8",
  44. )
  45. tmp_path = tmp.name
  46. tmp.close()
  47. file_handler = _setup_logging(tmp_path)
  48. try:
  49. query = os.getenv("PIPELINE_QUERY", "伊朗、以色列、和平是永恒的主题")
  50. demand_id = os.getenv("PIPELINE_DEMAND_ID", "1")
  51. result = await run_content_finder_from_cli(query=query, demand_id=demand_id)
  52. logger.info("pipeline trace_id=%s", result.trace_id)
  53. logger.info("pipeline output=%s", result.metadata.get("output_file", ""))
  54. # 将日志文件移入 trace 目录
  55. file_handler.close()
  56. trace_dir = os.path.join("tests", "traces", result.trace_id)
  57. os.makedirs(trace_dir, exist_ok=True)
  58. dest = os.path.join(trace_dir, "full_log.log")
  59. shutil.move(tmp_path, dest)
  60. logger.info("完整日志已保存: %s", dest)
  61. finally:
  62. if os.path.exists(tmp_path):
  63. try:
  64. os.unlink(tmp_path)
  65. except OSError:
  66. pass
if __name__ == "__main__":
    # Script entry point: drive the async pipeline to completion.
    asyncio.run(main())