""" 内容寻找服务 提供: 1. API 接口:POST /api/tasks - 触发内容寻找任务 2. 定时调度:每 10 分钟从数据库联表查询未处理需求并执行任务 3. 并发控制:限制最大并发任务数 """ import asyncio import logging import os import uuid from datetime import datetime from pathlib import Path from typing import Optional import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from fastapi import FastAPI, HTTPException from pydantic import BaseModel from apscheduler.schedulers.asyncio import AsyncIOScheduler from dotenv import load_dotenv load_dotenv() import core from db import get_next_unprocessed_demand, create_task_record, update_task_status, update_task_on_complete from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED # 配置日志 log_dir = Path(__file__).parent / '.cache' log_dir.mkdir(exist_ok=True) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_dir / 'server.log'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # FastAPI 应用 app = FastAPI( title="内容寻找服务", version="1.0.0", description="抖音内容寻找 Agent 服务" ) # 定时调度器 scheduler = AsyncIOScheduler() # 并发控制 MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "3")) task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) # 统计信息 stats = { "total_tasks": 0, "completed_tasks": 0, "failed_tasks": 0, "scheduled_tasks": 0 } # ============ 数据模型 ============ class TaskRequest(BaseModel): query: Optional[str] = None demand_id: Optional[int] = None class TaskResponse(BaseModel): trace_id: str status: str query: str message: str # ============ 核心函数 ============ def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None: """定时任务完成时更新 trace_id 和 status,静默处理异常""" try: update_task_on_complete(demand_id, trace_id, status) except Exception as e: logger.warning(f"更新任务状态失败: {e}") async def execute_task( query: str, demand_id: Optional[int] = None, task_type: str = "api", ): """ 执行任务(带并发控制) Args: query: 查询内容 demand_id: 需求 id(demand_content.id,关联 demand_content 表) task_type: 任务类型("api" 或 "scheduled") """ async with task_semaphore: current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1 logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}") start_time = datetime.now() stats["total_tasks"] += 1 if task_type == "scheduled": stats["scheduled_tasks"] += 1 if task_type == "scheduled" and demand_id is not None: try: update_task_status("", demand_id, STATUS_RUNNING) except Exception as e: logger.warning(f"更新任务状态为执行中失败: {e}") try: result = await core.run_agent( query, demand_id=demand_id, stream_output=False, trace_id=None ) duration = (datetime.now() - start_time).total_seconds() if result["status"] == "completed": stats["completed_tasks"] += 1 logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s") if task_type == "scheduled" and demand_id is not None: _update_scheduled_task_complete(demand_id, result["trace_id"], STATUS_SUCCESS) else: stats["failed_tasks"] += 1 logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s") if task_type == "scheduled" and demand_id is not None: _update_scheduled_task_complete(demand_id, result.get("trace_id") or "", STATUS_FAILED) except Exception as e: stats["failed_tasks"] += 1 duration = (datetime.now() - start_time).total_seconds() logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True) if task_type == "scheduled" and demand_id is not None: _update_scheduled_task_complete(demand_id, "", STATUS_FAILED) async def scheduled_task(): """ 定时任务:每 10 分钟执行一次 流程: 1. 联表查询 demand_content + demand_find_task,获取创建时间最早的未处理的 demand_content 2. 在 demand_find_task 新增记录 3. 调用 execute_task 执行 """ logger.info("定时任务触发") demand = get_next_unprocessed_demand() if not demand: logger.info("定时任务跳过:无待处理需求") return query = demand.get("query") or "" if not query: logger.info("定时任务跳过:该需求的 query 为空") return demand_content_id = demand.get("demand_content_id") if demand_content_id is None: logger.warning("定时任务跳过:demand_content_id 为空") return create_task_record(demand_content_id) # trace_id 初始为空,完成后更新 asyncio.create_task(execute_task(query=query, demand_id=demand_content_id, task_type="scheduled")) # ============ API 接口 ============ @app.post("/api/tasks", response_model=TaskResponse) async def create_task(request: TaskRequest): """ 创建内容寻找任务 Args: request.query: 查询内容(可选,不传则使用默认值) Returns: { "trace_id": "20260317_103046_xyz789", "status": "started", "query": "...", "message": "任务已启动,结果将保存到 .cache/traces/xxx/" } """ # 获取 query 和 demand_id query = request.query or core.DEFAULT_QUERY demand_id = request.demand_id # 用 Event 等待 trace_id trace_id_ready = asyncio.Event() trace_id_holder = {"id": None} async def run_and_capture(): try: # 获取第一个 Trace 对象来获取 trace_id from agent import Trace async with task_semaphore: # 重新构建 runner 来获取 trace_id from agent import AgentRunner, RunConfig, FileSystemTraceStore from agent.llm import create_openrouter_llm_call from agent.llm.prompts import SimplePrompt from agent.tools.builtin.knowledge import KnowledgeConfig prompt_path = Path(__file__).parent / "content_finder.prompt" prompt = SimplePrompt(prompt_path) trace_dir = os.getenv("TRACE_DIR", ".cache/traces") demand_id_str = str(demand_id) if demand_id is not None else "" messages = prompt.build_messages(query=query, trace_dir=trace_dir, demand_id=demand_id_str) api_key = os.getenv("OPEN_ROUTER_API_KEY") model_name = prompt.config.get("model", "sonnet-4.6") model = os.getenv("MODEL", f"anthropic/claude-{model_name}") temperature = float(prompt.config.get("temperature", 0.3)) max_iterations = int(os.getenv("MAX_ITERATIONS", "30")) trace_dir = os.getenv("TRACE_DIR", ".cache/traces") skills_dir = str(Path(__file__).parent / "skills") Path(trace_dir).mkdir(parents=True, exist_ok=True) store = FileSystemTraceStore(base_path=trace_dir) allowed_tools = [ "douyin_search", "douyin_user_videos", "get_content_fans_portrait", "get_account_fans_portrait", "store_results_mysql", ] runner = AgentRunner( llm_call=create_openrouter_llm_call(model=model), trace_store=store, skills_dir=skills_dir, ) config = RunConfig( name="内容寻找", model=model, temperature=temperature, max_iterations=max_iterations, tools=allowed_tools, extra_llm_params={"max_tokens": 8192}, knowledge=KnowledgeConfig( enable_extraction=True, enable_completion_extraction=True, enable_injection=True, owner="content_finder_agent", default_tags={"project": "content_finder"}, default_scopes=["com.piaoquantv.supply"], default_search_types=["tool", "usecase", "definition"], default_search_owner="content_finder_agent" ) ) async for item in runner.run(messages=messages, config=config): if isinstance(item, Trace): if not trace_id_holder["id"]: trace_id_holder["id"] = item.trace_id trace_id_ready.set() logger.info(f"任务启动 [api]: trace_id={item.trace_id}") if item.status == "completed": stats["completed_tasks"] += 1 logger.info(f"任务完成 [api]: trace_id={item.trace_id}") break elif item.status == "failed": stats["failed_tasks"] += 1 logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}") break except Exception as e: stats["failed_tasks"] += 1 logger.error(f"任务异常 [api]: {e}", exc_info=True) if not trace_id_holder["id"]: trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}" trace_id_ready.set() # 启动后台任务 stats["total_tasks"] += 1 asyncio.create_task(run_and_capture()) # 等待 trace_id(最多 5 秒) try: await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0) except asyncio.TimeoutError: logger.error("获取 trace_id 超时") raise HTTPException(status_code=500, detail="任务启动超时") trace_id = trace_id_holder["id"] return TaskResponse( trace_id=trace_id, status="started", query=query, message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/" ) @app.get("/health") async def health_check(): """健康检查""" return { "status": "ok", "max_concurrent_tasks": MAX_CONCURRENT_TASKS, "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value, "scheduler_running": scheduler.running, "stats": stats } @app.get("/") async def root(): """根路径""" return { "service": "内容寻找服务", "version": "1.0.0", "endpoints": { "create_task": "POST /api/tasks", "health": "GET /health" } } # ============ 启动事件 ============ @app.on_event("startup") async def startup(): """服务启动时初始化""" logger.info("=" * 60) logger.info("内容寻找服务启动中...") logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}") # 配置定时任务(从 demand_content 联表查询未处理需求,无需外部 API) scheduler.add_job(scheduled_task, "cron", minute="*/10") scheduler.start() logger.info("定时任务已启动:每 10 分钟执行一次(从数据库获取待处理需求)") logger.info("服务启动完成") logger.info("=" * 60) @app.on_event("shutdown") async def shutdown(): """服务关闭时清理""" logger.info("服务关闭中...") if scheduler.running: scheduler.shutdown() logger.info("服务已关闭") # ============ 主函数 ============ if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", "8080")) host = os.getenv("HOST", "0.0.0.0") logger.info(f"启动服务: http://{host}:{port}") uvicorn.run(app, host=host, port=port)