core.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. """
  2. 内容寻找 Agent - 核心执行逻辑
  3. 提供可复用的 agent 执行函数,供 run.py 和 server.py 调用。
  4. """
  5. import asyncio
  6. import logging
  7. import sys
  8. import os
  9. from pathlib import Path
  10. from typing import Optional, Dict, Any
  11. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  12. from dotenv import load_dotenv
  13. load_dotenv()
  14. from agent import (
  15. AgentRunner,
  16. RunConfig,
  17. FileSystemTraceStore,
  18. Trace,
  19. Message,
  20. )
  21. from agent.llm import create_openrouter_llm_call
  22. from agent.llm.prompts import SimplePrompt
  23. from agent.tools.builtin.knowledge import KnowledgeConfig
  24. # 导入工具(确保工具被注册)
  25. from tools import (
  26. douyin_search,
  27. douyin_user_videos,
  28. get_content_fans_portrait,
  29. get_account_fans_portrait,
  30. create_crawler_plan_by_douyin_content_id,
  31. create_crawler_plan_by_douyin_account_id,
  32. store_results_mysql,
  33. )
  34. logger = logging.getLogger(__name__)
  35. # 默认搜索词
  36. DEFAULT_QUERY = "毛泽东1965年深秋预言"
  37. DEFAULT_DEMAND_ID = 2629
  38. async def run_agent(
  39. query: Optional[str] = None,
  40. demand_id: Optional[int] = None,
  41. stream_output: bool = True,
  42. ) -> Dict[str, Any]:
  43. """
  44. 执行 agent 任务
  45. Args:
  46. query: 查询内容(搜索词),None 则使用默认值
  47. demand_id: 本次搜索任务 id(int,关联 demand_content 表)
  48. stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
  49. Returns:
  50. {
  51. "trace_id": "20260317_103046_xyz789",
  52. "status": "completed" | "failed",
  53. "error": "错误信息" # 失败时
  54. }
  55. """
  56. query = query or DEFAULT_QUERY
  57. demand_id = demand_id or DEFAULT_DEMAND_ID
  58. # 加载 prompt
  59. prompt_path = Path(__file__).parent / "content_finder.prompt"
  60. prompt = SimplePrompt(prompt_path)
  61. # output 目录
  62. output_dir = os.getenv("OUTPUT_DIR", ".cache/output")
  63. # 构建消息(替换 %query%、%output_dir%、%demand_id%)
  64. demand_id_str = str(demand_id) if demand_id is not None else ""
  65. messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
  66. # 初始化配置
  67. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  68. if not api_key:
  69. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  70. model_name = prompt.config.get("model", "sonnet-4.6")
  71. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  72. temperature = float(prompt.config.get("temperature", 0.3))
  73. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  74. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  75. skills_dir = str(Path(__file__).parent / "skills")
  76. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  77. store = FileSystemTraceStore(base_path=trace_dir)
  78. allowed_tools = [
  79. "douyin_search",
  80. "douyin_user_videos",
  81. "get_content_fans_portrait",
  82. "get_account_fans_portrait",
  83. "store_results_mysql",
  84. "create_crawler_plan_by_douyin_content_id",
  85. "create_crawler_plan_by_douyin_account_id",
  86. ]
  87. runner = AgentRunner(
  88. llm_call=create_openrouter_llm_call(model=model),
  89. trace_store=store,
  90. skills_dir=skills_dir,
  91. )
  92. config = RunConfig(
  93. name="内容寻找",
  94. model=model,
  95. temperature=temperature,
  96. max_iterations=max_iterations,
  97. tools=allowed_tools,
  98. extra_llm_params={"max_tokens": 8192},
  99. knowledge=KnowledgeConfig(
  100. enable_extraction=True,
  101. enable_completion_extraction=True,
  102. enable_injection=True,
  103. owner="content_finder_agent",
  104. default_tags={"project": "content_finder"},
  105. default_scopes=["com.piaoquantv.supply"],
  106. default_search_types=["tool", "usecase", "definition"],
  107. default_search_owner="content_finder_agent"
  108. )
  109. )
  110. # 执行
  111. trace_id = None
  112. try:
  113. async for item in runner.run(messages=messages, config=config):
  114. if isinstance(item, Trace):
  115. trace_id = item.trace_id
  116. if item.status == "completed":
  117. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  118. return {
  119. "trace_id": trace_id,
  120. "status": "completed"
  121. }
  122. elif item.status == "failed":
  123. logger.error(f"Agent 执行失败: {item.error_message}")
  124. return {
  125. "trace_id": trace_id,
  126. "status": "failed",
  127. "error": item.error_message
  128. }
  129. elif isinstance(item, Message) and stream_output:
  130. # 流式输出(仅 run.py 需要)
  131. if item.role == "assistant":
  132. content = item.content
  133. if isinstance(content, dict):
  134. text = content.get("text", "")
  135. tool_calls = content.get("tool_calls", [])
  136. if text:
  137. # 如果有推荐结果,完整输出
  138. if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
  139. print(f"\n{text}")
  140. # 如果有工具调用且文本较短,只输出摘要
  141. elif tool_calls and len(text) > 100:
  142. print(f"[思考] {text[:100]}...")
  143. # 其他情况输出完整文本
  144. else:
  145. print(f"\n{text}")
  146. # 输出工具调用信息
  147. if tool_calls:
  148. for tc in tool_calls:
  149. tool_name = tc.get("function", {}).get("name", "unknown")
  150. # 跳过 goal 工具的输出,减少噪音
  151. if tool_name != "goal":
  152. print(f"[工具] {tool_name}")
  153. elif isinstance(content, str) and content:
  154. print(f"\n{content}")
  155. elif item.role == "tool":
  156. content = item.content
  157. if isinstance(content, dict):
  158. tool_name = content.get("tool_name", "unknown")
  159. print(f"[结果] {tool_name} ✓")
  160. # 如果循环结束但没有返回,说明异常退出
  161. return {
  162. "trace_id": trace_id,
  163. "status": "failed",
  164. "error": "Agent 异常退出"
  165. }
  166. except KeyboardInterrupt:
  167. logger.info("用户中断")
  168. if stream_output:
  169. print("\n用户中断")
  170. return {
  171. "trace_id": trace_id,
  172. "status": "failed",
  173. "error": "用户中断"
  174. }
  175. except Exception as e:
  176. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  177. if stream_output:
  178. print(f"\n执行失败: {e}")
  179. return {
  180. "trace_id": trace_id,
  181. "status": "failed",
  182. "error": str(e)
  183. }