run_single.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from typing import Dict, Any, Optional
  2. import os
  3. from pathlib import Path
  4. from agent import AgentRunner, RunConfig, FileSystemTraceStore, Trace, Message
  5. from agent.llm import create_openrouter_llm_call
  6. from agent.llm.prompts import SimplePrompt
  7. from agent.tools.builtin.knowledge import KnowledgeConfig
  8. # 默认搜索词
  9. DEFAULT_QUERY = "戏曲表演"
  10. DEFAULT_DEMAND_ID = 1
  11. import logging
  12. logger = logging.getLogger(__name__)
  13. PROJECT_ROOT = Path(__file__).resolve().parent
  14. async def run_agent(
  15. query: Optional[str] = None,
  16. demand_id: Optional[int] = None,
  17. stream_output: bool = True,
  18. ) -> Dict[str, Any]:
  19. """
  20. 执行 agent 任务
  21. Args:
  22. query: 查询内容(搜索词),None 则使用默认值
  23. demand_id: 本次搜索任务 id(int,关联 demand_content 表)
  24. stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
  25. Returns:
  26. {
  27. "trace_id": "20260317_103046_xyz789",
  28. "status": "completed" | "failed",
  29. "error": "错误信息" # 失败时
  30. }
  31. """
  32. query = query or DEFAULT_QUERY
  33. demand_id = demand_id or DEFAULT_DEMAND_ID
  34. # 加载 prompt
  35. prompt_path = PROJECT_ROOT / "tests" / "content_finder.prompt"
  36. prompt = SimplePrompt(prompt_path)
  37. # output 目录
  38. output_dir = str(PROJECT_ROOT / "tests" / "output")
  39. # 构建消息(替换 %query%、%output_dir%、%demand_id%)
  40. demand_id_str = str(demand_id) if demand_id is not None else ""
  41. messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
  42. # 初始化配置
  43. api_key = "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"
  44. print(api_key)
  45. if not api_key:
  46. raise ValueError("OPEN_ROUTER_API_KEY 未设置")
  47. model_name = prompt.config.get("model", "sonnet-4.6")
  48. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  49. temperature = float(prompt.config.get("temperature", 0.3))
  50. max_iterations = 30
  51. trace_dir = str(PROJECT_ROOT / "tests" / "traces")
  52. skills_dir = str(PROJECT_ROOT / "skills")
  53. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  54. store = FileSystemTraceStore(base_path=trace_dir)
  55. allowed_tools = [
  56. "douyin_search",
  57. "douyin_user_videos",
  58. "get_content_fans_portrait",
  59. "get_account_fans_portrait",
  60. "store_results_mysql",
  61. "create_crawler_plan_by_douyin_content_id",
  62. "create_crawler_plan_by_douyin_account_id",
  63. ]
  64. runner = AgentRunner(
  65. llm_call=create_openrouter_llm_call(model=model),
  66. trace_store=store,
  67. skills_dir=skills_dir,
  68. )
  69. config = RunConfig(
  70. name="内容寻找",
  71. model=model,
  72. temperature=temperature,
  73. max_iterations=max_iterations,
  74. tools=allowed_tools,
  75. extra_llm_params={"max_tokens": 8192},
  76. knowledge=KnowledgeConfig(
  77. enable_extraction=True,
  78. enable_completion_extraction=True,
  79. enable_injection=True,
  80. owner="content_finder_agent",
  81. default_tags={"project": "content_finder"},
  82. default_scopes=["com.piaoquantv.supply"],
  83. default_search_types=["tool", "usecase", "definition"],
  84. default_search_owner="content_finder_agent"
  85. )
  86. )
  87. # 执行
  88. trace_id = None
  89. try:
  90. async for item in runner.run(messages=messages, config=config):
  91. if isinstance(item, Trace):
  92. trace_id = item.trace_id
  93. if item.status == "completed":
  94. logger.info(f"Agent 执行完成: trace_id={trace_id}")
  95. return {
  96. "trace_id": trace_id,
  97. "status": "completed"
  98. }
  99. elif item.status == "failed":
  100. logger.error(f"Agent 执行失败: {item.error_message}")
  101. return {
  102. "trace_id": trace_id,
  103. "status": "failed",
  104. "error": item.error_message
  105. }
  106. elif isinstance(item, Message) and stream_output:
  107. # 流式输出(仅 run.py 需要)
  108. if item.role == "assistant":
  109. content = item.content
  110. if isinstance(content, dict):
  111. text = content.get("text", "")
  112. tool_calls = content.get("tool_calls", [])
  113. if text:
  114. # 如果有推荐结果,完整输出
  115. if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
  116. print(f"\n{text}")
  117. # 如果有工具调用且文本较短,只输出摘要
  118. elif tool_calls and len(text) > 100:
  119. print(f"[思考] {text[:100]}...")
  120. # 其他情况输出完整文本
  121. else:
  122. print(f"\n{text}")
  123. # 输出工具调用信息
  124. if tool_calls:
  125. for tc in tool_calls:
  126. tool_name = tc.get("function", {}).get("name", "unknown")
  127. # 跳过 goal 工具的输出,减少噪音
  128. if tool_name != "goal":
  129. print(f"[工具] {tool_name}")
  130. elif isinstance(content, str) and content:
  131. print(f"\n{content}")
  132. elif item.role == "tool":
  133. content = item.content
  134. if isinstance(content, dict):
  135. tool_name = content.get("tool_name", "unknown")
  136. print(f"[结果] {tool_name} ✓")
  137. # 如果循环结束但没有返回,说明异常退出
  138. return {
  139. "trace_id": trace_id,
  140. "status": "failed",
  141. "error": "Agent 异常退出"
  142. }
  143. except KeyboardInterrupt:
  144. logger.info("用户中断")
  145. if stream_output:
  146. print("\n用户中断")
  147. return {
  148. "trace_id": trace_id,
  149. "status": "failed",
  150. "error": "用户中断"
  151. }
  152. except Exception as e:
  153. logger.error(f"Agent 执行异常: {e}", exc_info=True)
  154. if stream_output:
  155. print(f"\n执行失败: {e}")
  156. return {
  157. "trace_id": trace_id,
  158. "status": "failed",
  159. "error": str(e)
  160. }
  161. if __name__ == "__main__":
  162. import asyncio
  163. asyncio.run(run_agent())