|
|
@@ -0,0 +1,191 @@
|
|
|
+from typing import Dict, Any, Optional
|
|
|
+import os
|
|
|
+from pathlib import Path
|
|
|
+from agent import AgentRunner, RunConfig, FileSystemTraceStore, Trace, Message
|
|
|
+from agent.llm import create_openrouter_llm_call
|
|
|
+from agent.llm.prompts import SimplePrompt
|
|
|
+from agent.tools.builtin.knowledge import KnowledgeConfig
|
|
|
+
|
|
|
+# 默认搜索词
|
|
|
+DEFAULT_QUERY = "戏曲表演"
|
|
|
+DEFAULT_DEMAND_ID = 1
|
|
|
+
|
|
|
+import logging
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+async def run_agent(
|
|
|
+ query: Optional[str] = None,
|
|
|
+ demand_id: Optional[int] = None,
|
|
|
+ stream_output: bool = True,
|
|
|
+) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 执行 agent 任务
|
|
|
+
|
|
|
+ Args:
|
|
|
+ query: 查询内容(搜索词),None 则使用默认值
|
|
|
+ demand_id: 本次搜索任务 id(int,关联 demand_content 表)
|
|
|
+ stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ {
|
|
|
+ "trace_id": "20260317_103046_xyz789",
|
|
|
+ "status": "completed" | "failed",
|
|
|
+ "error": "错误信息" # 失败时
|
|
|
+ }
|
|
|
+ """
|
|
|
+ query = query or DEFAULT_QUERY
|
|
|
+ demand_id = demand_id or DEFAULT_DEMAND_ID
|
|
|
+
|
|
|
+ # 加载 prompt
|
|
|
+ prompt_path = "content_finder.prompt"
|
|
|
+ prompt = SimplePrompt(prompt_path)
|
|
|
+
|
|
|
+ # output 目录
|
|
|
+ output_dir = "output"
|
|
|
+
|
|
|
+ # 构建消息(替换 %query%、%output_dir%、%demand_id%)
|
|
|
+ demand_id_str = str(demand_id) if demand_id is not None else ""
|
|
|
+ messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
|
|
|
+
|
|
|
+ # 初始化配置
|
|
|
+ api_key = os.getenv("OPEN_ROUTER_API_KEY")
|
|
|
+ if not api_key:
|
|
|
+ raise ValueError("OPEN_ROUTER_API_KEY 未设置")
|
|
|
+
|
|
|
+ model_name = prompt.config.get("model", "sonnet-4.6")
|
|
|
+ model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
|
|
|
+ temperature = float(prompt.config.get("temperature", 0.3))
|
|
|
+ max_iterations = 30
|
|
|
+ trace_dir = "traces"
|
|
|
+
|
|
|
+ skills_dir = str(Path(__file__).parent / "skills")
|
|
|
+
|
|
|
+ Path(trace_dir).mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ store = FileSystemTraceStore(base_path=trace_dir)
|
|
|
+
|
|
|
+ allowed_tools = [
|
|
|
+ "douyin_search",
|
|
|
+ "douyin_user_videos",
|
|
|
+ "get_content_fans_portrait",
|
|
|
+ "get_account_fans_portrait",
|
|
|
+ "store_results_mysql",
|
|
|
+ "create_crawler_plan_by_douyin_content_id",
|
|
|
+ "create_crawler_plan_by_douyin_account_id",
|
|
|
+ ]
|
|
|
+
|
|
|
+ runner = AgentRunner(
|
|
|
+ llm_call=create_openrouter_llm_call(model=model),
|
|
|
+ trace_store=store,
|
|
|
+ skills_dir=skills_dir,
|
|
|
+ )
|
|
|
+
|
|
|
+ config = RunConfig(
|
|
|
+ name="内容寻找",
|
|
|
+ model=model,
|
|
|
+ temperature=temperature,
|
|
|
+ max_iterations=max_iterations,
|
|
|
+ tools=allowed_tools,
|
|
|
+ extra_llm_params={"max_tokens": 8192},
|
|
|
+ knowledge=KnowledgeConfig(
|
|
|
+ enable_extraction=True,
|
|
|
+ enable_completion_extraction=True,
|
|
|
+ enable_injection=True,
|
|
|
+ owner="content_finder_agent",
|
|
|
+ default_tags={"project": "content_finder"},
|
|
|
+ default_scopes=["com.piaoquantv.supply"],
|
|
|
+ default_search_types=["tool", "usecase", "definition"],
|
|
|
+ default_search_owner="content_finder_agent"
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ # 执行
|
|
|
+ trace_id = None
|
|
|
+
|
|
|
+ try:
|
|
|
+ async for item in runner.run(messages=messages, config=config):
|
|
|
+ if isinstance(item, Trace):
|
|
|
+ trace_id = item.trace_id
|
|
|
+
|
|
|
+ if item.status == "completed":
|
|
|
+ logger.info(f"Agent 执行完成: trace_id={trace_id}")
|
|
|
+ return {
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "status": "completed"
|
|
|
+ }
|
|
|
+ elif item.status == "failed":
|
|
|
+ logger.error(f"Agent 执行失败: {item.error_message}")
|
|
|
+ return {
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "status": "failed",
|
|
|
+ "error": item.error_message
|
|
|
+ }
|
|
|
+
|
|
|
+ elif isinstance(item, Message) and stream_output:
|
|
|
+ # 流式输出(仅 run.py 需要)
|
|
|
+ if item.role == "assistant":
|
|
|
+ content = item.content
|
|
|
+ if isinstance(content, dict):
|
|
|
+ text = content.get("text", "")
|
|
|
+ tool_calls = content.get("tool_calls", [])
|
|
|
+
|
|
|
+ if text:
|
|
|
+ # 如果有推荐结果,完整输出
|
|
|
+ if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
|
|
|
+ print(f"\n{text}")
|
|
|
+ # 如果有工具调用且文本较短,只输出摘要
|
|
|
+ elif tool_calls and len(text) > 100:
|
|
|
+ print(f"[思考] {text[:100]}...")
|
|
|
+ # 其他情况输出完整文本
|
|
|
+ else:
|
|
|
+ print(f"\n{text}")
|
|
|
+
|
|
|
+ # 输出工具调用信息
|
|
|
+ if tool_calls:
|
|
|
+ for tc in tool_calls:
|
|
|
+ tool_name = tc.get("function", {}).get("name", "unknown")
|
|
|
+ # 跳过 goal 工具的输出,减少噪音
|
|
|
+ if tool_name != "goal":
|
|
|
+ print(f"[工具] {tool_name}")
|
|
|
+ elif isinstance(content, str) and content:
|
|
|
+ print(f"\n{content}")
|
|
|
+
|
|
|
+ elif item.role == "tool":
|
|
|
+ content = item.content
|
|
|
+ if isinstance(content, dict):
|
|
|
+ tool_name = content.get("tool_name", "unknown")
|
|
|
+ print(f"[结果] {tool_name} ✓")
|
|
|
+
|
|
|
+ # 如果循环结束但没有返回,说明异常退出
|
|
|
+ return {
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "status": "failed",
|
|
|
+ "error": "Agent 异常退出"
|
|
|
+ }
|
|
|
+
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ logger.info("用户中断")
|
|
|
+ if stream_output:
|
|
|
+ print("\n用户中断")
|
|
|
+ return {
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "status": "failed",
|
|
|
+ "error": "用户中断"
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Agent 执行异常: {e}", exc_info=True)
|
|
|
+ if stream_output:
|
|
|
+ print(f"\n执行失败: {e}")
|
|
|
+ return {
|
|
|
+ "trace_id": trace_id,
|
|
|
+ "status": "failed",
|
|
|
+ "error": str(e)
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import asyncio
|
|
|
+ asyncio.run(run_agent())
|
|
|
+
|