core.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. """
  2. 内容寻找 Agent - 核心执行逻辑
  3. 提供可复用的 agent 执行函数,供 run.py 和 server.py 调用。
  4. """
  5. import asyncio
  6. import logging
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from typing import Optional, Dict, Any
  11. from utils.log_capture import build_log, log
  12. from datetime import datetime
  13. import uuid
  14. def _resolve_input_log_dir(content_finder_root: Path) -> Path:
  15. """与 .env 中 INPUT_LOG_PATH 一致:目录;相对路径相对 content_finder 根目录。"""
  16. raw = os.getenv("INPUT_LOG_PATH", ".cache/input_log")
  17. p = Path(raw).expanduser()
  18. if p.is_absolute():
  19. return p if not p.suffix else p.parent
  20. return (content_finder_root / p).resolve()
# Put the directory two levels above this file's directory at the front of
# sys.path (presumably the repository root -- confirm) so the project-local
# imports that follow can be resolved.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from dotenv import load_dotenv

# First pass: load_dotenv() with its default .env discovery.
load_dotenv()
# Second pass: make sure the .env located next to this file (INPUT_LOG_PATH etc.)
# is read even when the process is started from the repository root.
# override=True lets values from this file replace ones that are already set.
load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=True)
  26. from agent import (
  27. AgentRunner,
  28. RunConfig,
  29. FileSystemTraceStore,
  30. Trace,
  31. Message,
  32. )
  33. from agent.llm import create_openrouter_llm_call
  34. from agent.llm.prompts import SimplePrompt
  35. from agent.tools.builtin.knowledge import KnowledgeConfig
  36. # 导入工具(确保工具被注册)
  37. from tools import (
  38. douyin_search,
  39. douyin_search_fallback,
  40. douyin_user_videos,
  41. get_content_fans_portrait,
  42. get_account_fans_portrait,
  43. create_crawler_plan_by_douyin_content_id,
  44. create_crawler_plan_by_douyin_account_id,
  45. store_results_mysql,
  46. think_and_plan,
  47. find_authors_from_db,
  48. get_video_topic,
  49. )
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default search terms; run_agent falls back to this when no query is given.
DEFAULT_QUERY = "毛泽东,反腐倡廉"
# Default demand id; run_agent falls back to this when no demand_id is given.
DEFAULT_DEMAND_ID = 1
  54. def extract_assistant_text(message: Message) -> str:
  55. if message.role != "assistant":
  56. return ""
  57. content = message.content
  58. if isinstance(content, str):
  59. return content
  60. if isinstance(content, dict):
  61. text = content.get("text", "")
  62. # 即使本轮包含工具调用,也打印模型给出的文本,便于观察每一步输出
  63. if text:
  64. return text
  65. return ""
  66. async def run_agent(
  67. query: Optional[str] = None,
  68. demand_id: Optional[int] = None,
  69. stream_output: bool = True,
  70. ) -> Dict[str, Any]:
  71. """
  72. 执行 agent 任务
  73. Args:
  74. query: 查询内容(搜索词),None 则使用默认值
  75. demand_id: 本次搜索任务 id(int,关联 demand_content 表)
  76. stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
  77. Returns:
  78. {
  79. "trace_id": "20260317_103046_xyz789",
  80. "status": "completed" | "failed",
  81. "error": "错误信息" # 失败时
  82. }
  83. """
  84. query = query or DEFAULT_QUERY
  85. demand_id = demand_id or DEFAULT_DEMAND_ID
  86. # 加载 prompt
  87. prompt_path = Path(__file__).parent / "content_finder.md"
  88. prompt = SimplePrompt(prompt_path)
  89. # output 目录(相对路径相对 content_finder)
  90. content_finder_root = Path(__file__).resolve().parent
  91. output_dir = os.getenv("OUTPUT_DIR", ".cache/output")
  92. output_dir_path = Path(output_dir).expanduser()
  93. if not output_dir_path.is_absolute():
  94. output_dir_path = (content_finder_root / output_dir_path).resolve()
  95. # 构建消息(替换 %query%、%output_dir%、%demand_id%)
  96. demand_id_str = str(demand_id) if demand_id is not None else ""
  97. messages = prompt.build_messages(
  98. query=query, output_dir=str(output_dir_path), demand_id=demand_id_str
  99. )
  100. # 初始化配置
  101. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  102. if not api_key:
  103. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  104. model_name = prompt.config.get("model", "sonnet-4.6")
  105. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  106. temperature = float(prompt.config.get("temperature", 0.3))
  107. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  108. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  109. skills_dir = str(Path(__file__).parent / "skills")
  110. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  111. store = FileSystemTraceStore(base_path=trace_dir)
  112. allowed_tools = [
  113. "douyin_search",
  114. "douyin_search_fallback",
  115. "douyin_user_videos",
  116. "get_content_fans_portrait",
  117. "get_account_fans_portrait",
  118. "find_authors_from_db",
  119. "store_results_mysql",
  120. "create_crawler_plan_by_douyin_content_id",
  121. "create_crawler_plan_by_douyin_account_id",
  122. "think_and_plan",
  123. "get_video_topic",
  124. ]
  125. runner = AgentRunner(
  126. llm_call=create_openrouter_llm_call(model=model),
  127. trace_store=store,
  128. skills_dir=skills_dir,
  129. )
  130. config = RunConfig(
  131. name="内容寻找",
  132. model=model,
  133. temperature=temperature,
  134. enable_research_flow = False,
  135. goal_compression = "none",
  136. force_side_branch = None,
  137. max_iterations=max_iterations,
  138. tools=allowed_tools,
  139. extra_llm_params={"max_tokens": 8192},
  140. knowledge=KnowledgeConfig(
  141. enable_extraction=False,
  142. enable_completion_extraction=False,
  143. enable_injection=False,
  144. # owner="content_finder_agent",
  145. # default_tags={"project": "content_finder"},
  146. # default_scopes=["com.piaoquantv.supply"],
  147. # default_search_types=["tool", "usecase", "definition"],
  148. # default_search_owner="content_finder_agent"
  149. )
  150. )
  151. # 执行
  152. trace_id = None
  153. execution_id = str(uuid.uuid4())
  154. try:
  155. log_dir = _resolve_input_log_dir(content_finder_root)
  156. log_dir.mkdir(parents=True, exist_ok=True)
  157. log_file_path = log_dir / f"run_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
  158. run_result: Optional[Dict[str, Any]] = None
  159. with build_log(execution_id) as log_buffer:
  160. async for item in runner.run(messages=messages, config=config):
  161. if isinstance(item, Trace):
  162. trace_id = item.trace_id
  163. if item.status == "completed":
  164. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  165. run_result = {
  166. "trace_id": trace_id,
  167. "status": "completed",
  168. }
  169. break
  170. if item.status == "failed":
  171. logger.error(f"Agent 执行失败: {item.error_message}")
  172. run_result = {
  173. "trace_id": trace_id,
  174. "status": "failed",
  175. "error": item.error_message,
  176. }
  177. break
  178. elif isinstance(item, Message) and stream_output:
  179. text = extract_assistant_text(item)
  180. if text:
  181. log(f"[assistant] {text}")
  182. if run_result is None:
  183. run_result = {
  184. "trace_id": trace_id,
  185. "status": "failed",
  186. "error": "Agent 异常退出",
  187. }
  188. full_log = log_buffer.getvalue()
  189. with open(log_file_path, "w", encoding="utf-8") as f:
  190. f.write(full_log)
  191. return run_result
  192. except KeyboardInterrupt:
  193. logger.info("用户中断")
  194. if stream_output:
  195. print("\n用户中断")
  196. return {
  197. "trace_id": trace_id,
  198. "status": "failed",
  199. "error": "用户中断"
  200. }
  201. except Exception as e:
  202. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  203. if stream_output:
  204. print(f"\n执行失败: {e}")
  205. return {
  206. "trace_id": trace_id,
  207. "status": "failed",
  208. "error": str(e)
  209. }