"""Content-finding service.

Provides:
1. API endpoint: POST /api/tasks - trigger a content-finding task
2. Scheduling: every 10 minutes, fetch a query from an external API and run it
3. Concurrency control: caps the number of simultaneously running agent tasks
"""
import asyncio
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Optional
import sys

# Make the project root importable before project-local imports below.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from dotenv import load_dotenv

load_dotenv()

import core

# Logging: mirror everything to a file under .cache/ and to the console.
log_dir = Path(__file__).parent / '.cache'
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="内容寻找服务",
    version="1.0.0",
    description="抖音内容寻找 Agent 服务"
)

# Scheduler for the periodic query-fetch job (started in startup()).
scheduler = AsyncIOScheduler()

# Concurrency control: at most MAX_CONCURRENT_TASKS agent runs at once.
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "3"))
task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

# Strong references to fire-and-forget tasks. The event loop keeps only a
# weak reference to tasks, so without this set an in-flight task could be
# garbage-collected before it finishes (see asyncio.create_task docs).
_background_tasks: set = set()

# Service-lifetime counters (reset on process restart).
stats = {
    "total_tasks": 0,
    "completed_tasks": 0,
    "failed_tasks": 0,
    "scheduled_tasks": 0
}


def _spawn(coro) -> "asyncio.Task":
    """Schedule *coro* as a background task, keeping a strong reference
    until it completes so it cannot be garbage-collected mid-run."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# ============ Data models ============

class TaskRequest(BaseModel):
    # Optional query text; when omitted, core.DEFAULT_QUERY is used.
    query: Optional[str] = None


class TaskResponse(BaseModel):
    trace_id: str
    status: str
    query: str
    message: str


# ============ Core functions ============

async def execute_task(query: str, task_type: str = "api"):
    """Run one agent task under the concurrency semaphore.

    Args:
        query: Query text passed to the agent.
        task_type: Origin of the task ("api" or "scheduled"); used only
            for logging and per-origin stats.
    """
    async with task_semaphore:
        # After acquiring, `_value` is the number of free slots, so the
        # number of running tasks (including this one) is MAX - _value.
        # The original added +1 here and overcounted by one.
        # NOTE: `_value` is a private attribute of asyncio.Semaphore.
        current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value
        logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
        start_time = datetime.now()
        stats["total_tasks"] += 1
        if task_type == "scheduled":
            stats["scheduled_tasks"] += 1
        try:
            # Run the agent without streaming output.
            result = await core.run_agent(query, stream_output=False)
            duration = (datetime.now() - start_time).total_seconds()
            if result["status"] == "completed":
                stats["completed_tasks"] += 1
                logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
            else:
                stats["failed_tasks"] += 1
                logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
        except Exception as e:
            stats["failed_tasks"] += 1
            duration = (datetime.now() - start_time).total_seconds()
            logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)


async def scheduled_task():
    """Periodic job, fired every 10 minutes.

    Flow:
        1. Fetch a query from the external API (SCHEDULE_QUERY_API).
        2. On success, launch the task in the background.
        3. On any failure, skip this run (no fallback query).
    """
    logger.info("定时任务触发")
    try:
        # 1. Fetch the query from the external API.
        query_api = os.getenv("SCHEDULE_QUERY_API")
        if not query_api:
            logger.warning("未配置 SCHEDULE_QUERY_API,跳过定时任务")
            return

        timeout = float(os.getenv("SCHEDULE_QUERY_API_TIMEOUT", "10.0"))
        async with httpx.AsyncClient() as client:
            headers = {}
            logger.info(f"调用外部 API: {query_api}")
            response = await client.get(
                query_api,
                headers=headers,
                timeout=timeout
            )
            response.raise_for_status()
            data = response.json()

        # 2. Extract the query.
        query = data.get("query")
        if not query:
            logger.info("定时任务跳过:外部 API 返回的 query 为空")
            return

        # 3. Launch the task without awaiting it (bounded by the semaphore).
        logger.info(f"定时任务启动: query={query[:50]}...")
        _spawn(execute_task(query, task_type="scheduled"))

    except httpx.HTTPStatusError as e:
        logger.error(f"定时任务失败:外部 API 返回错误 {e.response.status_code}: {e.response.text}")
    # TimeoutException is a subclass of RequestError, so it must be caught
    # first; in the original order this branch was unreachable.
    except httpx.TimeoutException:
        logger.error("定时任务失败:外部 API 请求超时")
    except httpx.RequestError as e:
        logger.error(f"定时任务失败:外部 API 请求失败: {e}")
    except Exception as e:
        logger.error(f"定时任务失败:未知错误: {e}", exc_info=True)


# ============ API endpoints ============

@app.post("/api/tasks", response_model=TaskResponse)
async def create_task(request: TaskRequest):
    """Create a content-finding task.

    Args:
        request.query: Query text (optional; defaults to core.DEFAULT_QUERY).

    Returns:
        {
            "trace_id": "20260317_103046_xyz789",
            "status": "started",
            "query": "...",
            "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
        }
    """
    query = request.query or core.DEFAULT_QUERY

    # The background run reports its trace_id through an Event + holder dict
    # so this handler can respond as soon as the run has actually started.
    trace_id_ready = asyncio.Event()
    trace_id_holder = {"id": None}

    async def run_and_capture():
        """Run the agent in the background and capture the first trace_id."""
        try:
            from agent import Trace
            async with task_semaphore:
                # Build a runner manually (instead of core.run_agent) so we
                # can observe Trace objects and surface the trace_id early.
                from agent import AgentRunner, RunConfig, FileSystemTraceStore
                from agent.llm import create_openrouter_llm_call
                from agent.llm.prompts import SimplePrompt
                from agent.tools.builtin.knowledge import KnowledgeConfig

                prompt_path = Path(__file__).parent / "content_finder.prompt"
                prompt = SimplePrompt(prompt_path)
                messages = prompt.build_messages(query=query)

                # Model selection: env MODEL wins, otherwise derived from
                # the prompt's configured model name.
                model_name = prompt.config.get("model", "sonnet-4.6")
                model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
                temperature = float(prompt.config.get("temperature", 0.3))
                max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
                trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
                skills_dir = str(Path(__file__).parent / "skills")
                Path(trace_dir).mkdir(parents=True, exist_ok=True)
                store = FileSystemTraceStore(base_path=trace_dir)
                allowed_tools = [
                    "douyin_search",
                    "douyin_user_videos",
                    "get_content_fans_portrait",
                    "get_account_fans_portrait",
                ]
                runner = AgentRunner(
                    llm_call=create_openrouter_llm_call(model=model),
                    trace_store=store,
                    skills_dir=skills_dir,
                )
                config = RunConfig(
                    name="内容寻找",
                    model=model,
                    temperature=temperature,
                    max_iterations=max_iterations,
                    tools=allowed_tools,
                    extra_llm_params={"max_tokens": 8192},
                    knowledge=KnowledgeConfig(
                        enable_extraction=True,
                        enable_completion_extraction=True,
                        enable_injection=True,
                        owner="content_finder_agent",
                        default_tags={"project": "content_finder"},
                        default_scopes=["com.piaoquantv.supply"],
                        default_search_types=["tool", "usecase", "definition"],
                        default_search_owner="content_finder_agent"
                    )
                )
                async for item in runner.run(messages=messages, config=config):
                    if isinstance(item, Trace):
                        # First Trace seen: publish the trace_id to the
                        # waiting request handler.
                        if not trace_id_holder["id"]:
                            trace_id_holder["id"] = item.trace_id
                            trace_id_ready.set()
                            logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
                        if item.status == "completed":
                            stats["completed_tasks"] += 1
                            logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
                            break
                        elif item.status == "failed":
                            stats["failed_tasks"] += 1
                            logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
                            break
        except Exception as e:
            stats["failed_tasks"] += 1
            logger.error(f"任务异常 [api]: {e}", exc_info=True)
            # Unblock the waiting handler with a synthetic error id.
            if not trace_id_holder["id"]:
                trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                trace_id_ready.set()

    # Launch the background run, keeping a strong reference to the task.
    stats["total_tasks"] += 1
    _spawn(run_and_capture())

    # Wait for the trace_id (at most 5 seconds).
    try:
        await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
    except asyncio.TimeoutError:
        logger.error("获取 trace_id 超时")
        raise HTTPException(status_code=500, detail="任务启动超时")

    trace_id = trace_id_holder["id"]
    return TaskResponse(
        trace_id=trace_id,
        status="started",
        query=query,
        message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
    )


@app.get("/health")
async def health_check():
    """Health check: concurrency state, scheduler state, and counters."""
    return {
        "status": "ok",
        "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
        # NOTE: `_value` is a private Semaphore attribute; MAX - _value is
        # the number of currently held slots (= running tasks).
        "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
        "scheduler_running": scheduler.running,
        "stats": stats
    }


@app.get("/")
async def root():
    """Root path: service metadata and endpoint listing."""
    return {
        "service": "内容寻找服务",
        "version": "1.0.0",
        "endpoints": {
            "create_task": "POST /api/tasks",
            "health": "GET /health"
        }
    }


# ============ Lifecycle events ============
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept as-is to avoid changing the public wiring.

@app.on_event("startup")
async def startup():
    """Initialize the service: log configuration and start the scheduler."""
    logger.info("=" * 60)
    logger.info("内容寻找服务启动中...")
    logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")

    # Enable the periodic job only when an external query API is configured.
    query_api = os.getenv("SCHEDULE_QUERY_API")
    if query_api:
        # Fire at every wall-clock 10-minute mark (cron: minute */10).
        scheduler.add_job(scheduled_task, "cron", minute="*/10")
        scheduler.start()
        logger.info("定时任务已启动:每 10 分钟执行一次")
        logger.info(f"外部 API: {query_api}")
    else:
        logger.info("未配置 SCHEDULE_QUERY_API,定时任务未启动")

    logger.info("服务启动完成")
    logger.info("=" * 60)


@app.on_event("shutdown")
async def shutdown():
    """Clean up on service shutdown: stop the scheduler if it is running."""
    logger.info("服务关闭中...")
    if scheduler.running:
        scheduler.shutdown()
    logger.info("服务已关闭")


# ============ Entry point ============

if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", "8080"))
    host = os.getenv("HOST", "0.0.0.0")
    logger.info(f"启动服务: http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)