瀏覽代碼

标准化服务

luojunhui 10 小時之前
父節點
當前提交
93131dc7b7
共有 1 個文件被更改,包括 191 次插入0 次删除
  1. 191 0
      tests/run_single.py

+ 191 - 0
tests/run_single.py

@@ -0,0 +1,191 @@
+from typing import Dict, Any, Optional
+import os
+from pathlib import Path
+from agent import AgentRunner, RunConfig, FileSystemTraceStore, Trace, Message
+from agent.llm import create_openrouter_llm_call
+from agent.llm.prompts import SimplePrompt
+from agent.tools.builtin.knowledge import KnowledgeConfig
+
+# 默认搜索词
+DEFAULT_QUERY = "戏曲表演"
+DEFAULT_DEMAND_ID = 1
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+async def run_agent(
+        query: Optional[str] = None,
+        demand_id: Optional[int] = None,
+        stream_output: bool = True,
+) -> Dict[str, Any]:
+    """
+    执行 agent 任务
+
+    Args:
+        query: 查询内容(搜索词),None 则使用默认值
+        demand_id: 本次搜索任务 id(int,关联 demand_content 表)
+        stream_output: 是否流式输出到 stdout(run.py 需要,server.py 不需要)
+
+    Returns:
+        {
+            "trace_id": "20260317_103046_xyz789",
+            "status": "completed" | "failed",
+            "error": "错误信息"  # 失败时
+        }
+    """
+    query = query or DEFAULT_QUERY
+    demand_id = demand_id or DEFAULT_DEMAND_ID
+
+    # 加载 prompt
+    prompt_path = "content_finder.prompt"
+    prompt = SimplePrompt(prompt_path)
+
+    # output 目录
+    output_dir = "output"
+
+    # 构建消息(替换 %query%、%output_dir%、%demand_id%)
+    demand_id_str = str(demand_id) if demand_id is not None else ""
+    messages = prompt.build_messages(query=query, output_dir=output_dir, demand_id=demand_id_str)
+
+    # 初始化配置
+    api_key = os.getenv("OPEN_ROUTER_API_KEY")
+    if not api_key:
+        raise ValueError("OPEN_ROUTER_API_KEY 未设置")
+
+    model_name = prompt.config.get("model", "sonnet-4.6")
+    model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
+    temperature = float(prompt.config.get("temperature", 0.3))
+    max_iterations = 30
+    trace_dir = "traces"
+
+    skills_dir = str(Path(__file__).parent / "skills")
+
+    Path(trace_dir).mkdir(parents=True, exist_ok=True)
+
+    store = FileSystemTraceStore(base_path=trace_dir)
+
+    allowed_tools = [
+        "douyin_search",
+        "douyin_user_videos",
+        "get_content_fans_portrait",
+        "get_account_fans_portrait",
+        "store_results_mysql",
+        "create_crawler_plan_by_douyin_content_id",
+        "create_crawler_plan_by_douyin_account_id",
+    ]
+
+    runner = AgentRunner(
+        llm_call=create_openrouter_llm_call(model=model),
+        trace_store=store,
+        skills_dir=skills_dir,
+    )
+
+    config = RunConfig(
+        name="内容寻找",
+        model=model,
+        temperature=temperature,
+        max_iterations=max_iterations,
+        tools=allowed_tools,
+        extra_llm_params={"max_tokens": 8192},
+        knowledge=KnowledgeConfig(
+            enable_extraction=True,
+            enable_completion_extraction=True,
+            enable_injection=True,
+            owner="content_finder_agent",
+            default_tags={"project": "content_finder"},
+            default_scopes=["com.piaoquantv.supply"],
+            default_search_types=["tool", "usecase", "definition"],
+            default_search_owner="content_finder_agent"
+        )
+    )
+
+    # 执行
+    trace_id = None
+
+    try:
+        async for item in runner.run(messages=messages, config=config):
+            if isinstance(item, Trace):
+                trace_id = item.trace_id
+
+                if item.status == "completed":
+                    logger.info(f"Agent 执行完成: trace_id={trace_id}")
+                    return {
+                        "trace_id": trace_id,
+                        "status": "completed"
+                    }
+                elif item.status == "failed":
+                    logger.error(f"Agent 执行失败: {item.error_message}")
+                    return {
+                        "trace_id": trace_id,
+                        "status": "failed",
+                        "error": item.error_message
+                    }
+
+            elif isinstance(item, Message) and stream_output:
+                # 流式输出(仅 run.py 需要)
+                if item.role == "assistant":
+                    content = item.content
+                    if isinstance(content, dict):
+                        text = content.get("text", "")
+                        tool_calls = content.get("tool_calls", [])
+
+                        if text:
+                            # 如果有推荐结果,完整输出
+                            if len(text) > 500 and ("推荐结果" in text or "推荐内容" in text or "🎯" in text):
+                                print(f"\n{text}")
+                            # 如果有工具调用且文本较短,只输出摘要
+                            elif tool_calls and len(text) > 100:
+                                print(f"[思考] {text[:100]}...")
+                            # 其他情况输出完整文本
+                            else:
+                                print(f"\n{text}")
+
+                        # 输出工具调用信息
+                        if tool_calls:
+                            for tc in tool_calls:
+                                tool_name = tc.get("function", {}).get("name", "unknown")
+                                # 跳过 goal 工具的输出,减少噪音
+                                if tool_name != "goal":
+                                    print(f"[工具] {tool_name}")
+                    elif isinstance(content, str) and content:
+                        print(f"\n{content}")
+
+                elif item.role == "tool":
+                    content = item.content
+                    if isinstance(content, dict):
+                        tool_name = content.get("tool_name", "unknown")
+                        print(f"[结果] {tool_name} ✓")
+
+        # 如果循环结束但没有返回,说明异常退出
+        return {
+            "trace_id": trace_id,
+            "status": "failed",
+            "error": "Agent 异常退出"
+        }
+
+    except KeyboardInterrupt:
+        logger.info("用户中断")
+        if stream_output:
+            print("\n用户中断")
+        return {
+            "trace_id": trace_id,
+            "status": "failed",
+            "error": "用户中断"
+        }
+    except Exception as e:
+        logger.error(f"Agent 执行异常: {e}", exc_info=True)
+        if stream_output:
+            print(f"\n执行失败: {e}")
+        return {
+            "trace_id": trace_id,
+            "status": "failed",
+            "error": str(e)
+        }
+
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(run_agent())
+