core.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
"""
Content Finder Agent - core execution logic.

Provides the reusable agent execution function that run.py and server.py call.
"""
import asyncio
import logging
import sys
import os
from pathlib import Path
from typing import Optional, Dict, Any
# NOTE(review): this import runs before the sys.path.insert further down, so
# `utils` must already be importable from the current sys.path (presumably the
# content_finder directory) — confirm.
from utils.log_capture import attach_log_file, build_log, log
from datetime import datetime
from zoneinfo import ZoneInfo
import uuid

# Timezone used for the timestamp embedded in run_log_<timestamp>.txt file names.
LOG_TZ = ZoneInfo("Asia/Shanghai")
  16. def _resolve_repo_root() -> Path:
  17. # /.../Agent/examples/content_finder/core.py -> repo root is /.../Agent
  18. return Path(__file__).resolve().parents[2]
  19. def _resolve_dir_from_env(repo_root: Path, raw: str) -> Path:
  20. p = Path(raw).expanduser()
  21. return p.resolve() if p.is_absolute() else (repo_root / p).resolve()
  22. def _resolve_log_file_path(
  23. *,
  24. content_finder_root: Path,
  25. output_dir_path: Path,
  26. trace_id: str | None,
  27. execution_id: str,
  28. ) -> Path:
  29. """
  30. 解析日志输出路径。
  31. 规则:
  32. - 如果设置了 INPUT_LOG_PATH:
  33. - 值为 OUTPUT_DIR / ${OUTPUT_DIR}:写入 OUTPUT_DIR/<trace_id>/log.txt
  34. - 绝对/相对路径:视为“目录”,写入 <dir>/run_log_<timestamp>.txt(兼容旧行为)
  35. - 未设置 INPUT_LOG_PATH:默认写入 OUTPUT_DIR/<trace_id>/log.txt
  36. """
  37. raw = (os.getenv("INPUT_LOG_PATH") or "").strip()
  38. dir_name = trace_id or execution_id
  39. if raw in {"OUTPUT_DIR", "${OUTPUT_DIR}"} or raw == "":
  40. return (output_dir_path / dir_name / "log.txt").resolve()
  41. p = Path(raw).expanduser()
  42. if not p.is_absolute():
  43. p = (content_finder_root / p).resolve()
  44. log_dir = p if not p.suffix else p.parent
  45. return (log_dir / f"run_log_{datetime.now(LOG_TZ).strftime('%Y%m%d_%H%M%S')}.txt").resolve()
# Prepend the directory three levels above this file to sys.path (.../Agent when
# this file is .../Agent/examples/content_finder/core.py); presumably this is what
# lets the `agent` / `tools` imports below resolve — confirm.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from dotenv import load_dotenv
# Load a .env located by python-dotenv's default search.
load_dotenv()
# Also load content_finder/.env (next to this file) so settings such as
# INPUT_LOG_PATH are read even when running from the repo root.
# override=True: values from this file win over variables that are already set.
load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=True)
from agent import (
    AgentRunner,
    RunConfig,
    FileSystemTraceStore,
    Trace,
    Message,
)
from agent.llm import create_openrouter_llm_call
from agent.llm.prompts import SimplePrompt
from agent.tools.builtin.knowledge import KnowledgeConfig
# Import the tools so that they get registered (import side effect).
from tools import (
    douyin_search,
    douyin_search_tikhub,
    douyin_user_videos,
    get_content_fans_portrait,
    get_account_fans_portrait,
    batch_fetch_portraits,
    create_crawler_plan_by_douyin_content_id,
    create_crawler_plan_by_douyin_account_id,
    store_results_mysql,
    think_and_plan,
    find_authors_from_db,
    get_video_topic,
    hot_topic_search,
)
logger = logging.getLogger(__name__)
# Defaults applied by run_agent when the corresponding argument is falsy.
DEFAULT_QUERY = "财神,祝福语"  # default search terms
DEFAULT_SUGGESTION = ""
DEFAULT_DEMAND_ID = 19443
  82. def extract_assistant_text(message: Message) -> str:
  83. if message.role != "assistant":
  84. return ""
  85. content = message.content
  86. if isinstance(content, str):
  87. return content
  88. if isinstance(content, dict):
  89. text = content.get("text", "")
  90. # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出
  91. if text:
  92. return text
  93. return ""
  94. async def run_agent(
  95. query: Optional[str] = None,
  96. demand_id: Optional[int] = None,
  97. suggestion: Optional[str] = None,
  98. stream_output: bool = True,
  99. log_assistant_text: bool = True,
  100. ) -> Dict[str, Any]:
  101. """
  102. 执行 agent 任务
  103. Args:
  104. query: 查询内容(搜索词),None 则使用默认值
  105. demand_id: 本次搜索任务 id(int,关联 demand_content 表)
  106. suggestion: 补充信息(与 query 同源时来自 demand_content.suggestion),None 则置空串参与占位符替换
  107. stream_output: 是否输出到 stdout(run.py 需要,server.py 不需要)
  108. log_assistant_text: 是否将 assistant 文本写入 log.txt(server 建议开启)
  109. Returns:
  110. {
  111. "trace_id": "20260317_103046_xyz789",
  112. "status": "completed" | "failed",
  113. "error": "错误信息" # 失败时
  114. }
  115. """
  116. query = query or DEFAULT_QUERY
  117. demand_id = demand_id or DEFAULT_DEMAND_ID
  118. suggestion_str = (suggestion or DEFAULT_SUGGESTION).strip()
  119. # 加载 prompt
  120. prompt_path = Path(__file__).parent / "content_finder.md"
  121. prompt = SimplePrompt(prompt_path)
  122. # output 目录(相对路径相对 content_finder)
  123. content_finder_root = Path(__file__).resolve().parent
  124. repo_root = _resolve_repo_root()
  125. output_dir = os.getenv("OUTPUT_DIR", ".cache/output")
  126. output_dir_path = _resolve_dir_from_env(repo_root, output_dir)
  127. # 构建消息(替换 %query%、%suggestion%、%output_dir%、%demand_id%)
  128. demand_id_str = str(demand_id) if demand_id is not None else ""
  129. messages = prompt.build_messages(
  130. query=query,
  131. suggestion=suggestion_str,
  132. output_dir=str(output_dir_path),
  133. demand_id=demand_id_str,
  134. )
  135. # 初始化配置
  136. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  137. if not api_key:
  138. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  139. model_name = prompt.config.get("model", "sonnet-4.6")
  140. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  141. temperature = float(prompt.config.get("temperature", 0.3))
  142. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  143. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  144. skills_dir = str(Path(__file__).parent / "skills")
  145. trace_dir_path = _resolve_dir_from_env(repo_root, trace_dir)
  146. trace_dir_path.mkdir(parents=True, exist_ok=True)
  147. store = FileSystemTraceStore(base_path=str(trace_dir_path))
  148. allowed_tools = [
  149. "douyin_search",
  150. "douyin_search_tikhub",
  151. "douyin_user_videos",
  152. "batch_fetch_portraits",
  153. "find_authors_from_db",
  154. "store_results_mysql",
  155. "create_crawler_plan_by_douyin_content_id",
  156. "create_crawler_plan_by_douyin_account_id",
  157. "think_and_plan",
  158. "get_video_topic",
  159. "hot_topic_search",
  160. ]
  161. runner = AgentRunner(
  162. llm_call=create_openrouter_llm_call(model=model),
  163. trace_store=store,
  164. skills_dir=skills_dir,
  165. )
  166. config = RunConfig(
  167. name="内容寻找",
  168. model=model,
  169. temperature=temperature,
  170. enable_research_flow = False,
  171. goal_compression = "none",
  172. force_side_branch = None,
  173. max_iterations=max_iterations,
  174. tools=allowed_tools,
  175. extra_llm_params={"max_tokens": 8192},
  176. knowledge=KnowledgeConfig(
  177. enable_extraction=False,
  178. enable_completion_extraction=False,
  179. enable_injection=False,
  180. # owner="content_finder_agent",
  181. # default_tags={"project": "content_finder"},
  182. # default_scopes=["com.piaoquantv.supply"],
  183. # default_search_types=["tool", "usecase", "definition"],
  184. # default_search_owner="content_finder_agent"
  185. )
  186. )
  187. # 执行
  188. trace_id = None
  189. execution_id = str(uuid.uuid4())
  190. try:
  191. run_result: Optional[Dict[str, Any]] = None
  192. with build_log(execution_id) as log_buffer:
  193. async for item in runner.run(messages=messages, config=config):
  194. if isinstance(item, Trace):
  195. trace_id = item.trace_id
  196. # 一旦拿到 trace_id,立即绑定日志文件,确保后续步骤(含 exec_summary)能读到实时 log.txt
  197. try:
  198. log_file_path = _resolve_log_file_path(
  199. content_finder_root=content_finder_root,
  200. output_dir_path=output_dir_path,
  201. trace_id=trace_id,
  202. execution_id=execution_id,
  203. )
  204. attach_log_file(execution_id, log_file_path)
  205. except Exception as e:
  206. logger.warning(f"绑定实时 log.txt 失败: trace_id={trace_id}, err={e}")
  207. if item.status == "completed":
  208. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  209. run_result = {
  210. "trace_id": trace_id,
  211. "status": "completed",
  212. }
  213. break
  214. if item.status == "failed":
  215. logger.error(f"Agent 执行失败: {item.error_message}")
  216. run_result = {
  217. "trace_id": trace_id,
  218. "status": "failed",
  219. "error": item.error_message,
  220. }
  221. break
  222. elif isinstance(item, Message):
  223. text = extract_assistant_text(item)
  224. if text and log_assistant_text:
  225. log(f"[assistant] {text}")
  226. if text and stream_output:
  227. print(text)
  228. if run_result is None:
  229. run_result = {
  230. "trace_id": trace_id,
  231. "status": "failed",
  232. "error": "Agent 异常退出",
  233. }
  234. full_log = log_buffer.getvalue()
  235. log_file_path = _resolve_log_file_path(
  236. content_finder_root=content_finder_root,
  237. output_dir_path=output_dir_path,
  238. trace_id=trace_id,
  239. execution_id=execution_id,
  240. )
  241. # 兜底:如果实时落盘失败/未绑定,则在结束时一次性写入
  242. try:
  243. if not log_file_path.exists():
  244. log_file_path.parent.mkdir(parents=True, exist_ok=True)
  245. with open(log_file_path, "w", encoding="utf-8") as f:
  246. f.write(full_log)
  247. except Exception as e:
  248. logger.warning(f"写入 log.txt 兜底失败: trace_id={trace_id}, err={e}")
  249. try:
  250. from render_log_html import render_log_html_and_upload
  251. if trace_id:
  252. url = render_log_html_and_upload(trace_id=trace_id, log_file_path=log_file_path)
  253. if url:
  254. logger.info(f"log.html 已上传: trace_id={trace_id}, url={url}")
  255. except Exception as e:
  256. logger.warning(f"渲染/上传 log.html 失败: trace_id={trace_id}, err={e}")
  257. return run_result
  258. except KeyboardInterrupt:
  259. logger.info("用户中断")
  260. if stream_output:
  261. print("\n用户中断")
  262. return {
  263. "trace_id": trace_id,
  264. "status": "failed",
  265. "error": "用户中断"
  266. }
  267. except Exception as e:
  268. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  269. if stream_output:
  270. print(f"\n执行失败: {e}")
  271. return {
  272. "trace_id": trace_id,
  273. "status": "failed",
  274. "error": str(e)
  275. }