| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- """
- 内容寻找服务
- 提供:
- 1. API 接口:POST /api/tasks - 触发内容寻找任务
- 2. 定时调度:每 10 分钟调用外部 API 获取 query 并执行任务
- 3. 并发控制:限制最大并发任务数
- """
- import asyncio
- import logging
- import os
- from datetime import datetime
- from pathlib import Path
- from typing import Optional
- import sys
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- import httpx
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- from apscheduler.schedulers.asyncio import AsyncIOScheduler
- from dotenv import load_dotenv
- load_dotenv()
- import core
# Logging: write to .cache/server.log next to this file and mirror to the console.
log_dir = Path(__file__).parent / '.cache'
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: log messages contain Chinese text, and FileHandler's
        # default encoding is locale-dependent (may not be able to encode it).
        logging.FileHandler(log_dir / 'server.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# The FastAPI application object that all routes below are registered on.
app = FastAPI(
    description="抖音内容寻找 Agent 服务",
    title="内容寻找服务",
    version="1.0.0",
)

# APScheduler scheduler bound to asyncio; the startup hook adds the periodic
# job and starts it only when SCHEDULE_QUERY_API is configured.
scheduler = AsyncIOScheduler()
# Concurrency limit: at most MAX_CONCURRENT_TASKS agent runs execute at once;
# further tasks wait on task_semaphore until a slot frees up.
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "3"))
if MAX_CONCURRENT_TASKS < 1:
    # Semaphore(0) would make every task wait forever (and every API call time
    # out); a negative value is rejected by asyncio anyway. Fail fast instead.
    raise ValueError(
        f"MAX_CONCURRENT_TASKS must be >= 1, got {MAX_CONCURRENT_TASKS}"
    )
task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

# Process-wide task counters, exposed via GET /health.
stats = {
    "total_tasks": 0,      # tasks started (API and scheduled)
    "completed_tasks": 0,  # runs that reported status "completed"
    "failed_tasks": 0,     # runs that reported failure or raised
    "scheduled_tasks": 0,  # subset of total_tasks started by the scheduler
}
- # ============ 数据模型 ============
class TaskRequest(BaseModel):
    # Request body for POST /api/tasks.
    # Query text for the agent. Optional: when missing or empty, create_task
    # falls back to core.DEFAULT_QUERY.
    query: Optional[str] = None
class TaskResponse(BaseModel):
    # Response body for POST /api/tasks.
    # Trace id reported by the agent run, or an "error_<timestamp>" placeholder
    # when the background run raised before producing a trace.
    trace_id: str
    # Status string set by create_task (e.g. "started").
    status: str
    # The query actually used (request value or the default query).
    query: str
    # Human-readable message, including where the trace results are saved.
    message: str
- # ============ 核心函数 ============
async def execute_task(query: str, task_type: str = "api"):
    """
    Run one agent task under the global concurrency limit.

    Waits for a slot on ``task_semaphore``, runs ``core.run_agent`` without
    streaming output, and updates the module-level ``stats`` counters.
    Failures are logged and counted; nothing is raised to the caller.

    Args:
        query: Query text passed to the agent.
        task_type: Origin of the task, used for logging and stats
            ("api" or "scheduled").
    """
    async with task_semaphore:
        # This task already holds its permit here, so "limit minus free permits"
        # counts it; adding 1 would double-count. `_value` is asyncio's private
        # free-permit counter (health_check computes the same figure).
        current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value
        logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
        start_time = datetime.now()
        stats["total_tasks"] += 1
        if task_type == "scheduled":
            stats["scheduled_tasks"] += 1
        try:
            # Run the agent without streaming output.
            result = await core.run_agent(query, stream_output=False)
            duration = (datetime.now() - start_time).total_seconds()
            if result.get("status") == "completed":
                stats["completed_tasks"] += 1
                logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
            else:
                stats["failed_tasks"] += 1
                logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
        except Exception as e:
            stats["failed_tasks"] += 1
            duration = (datetime.now() - start_time).total_seconds()
            logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
# Strong references to tasks spawned by scheduled_task. asyncio keeps only weak
# references to running tasks, so an unreferenced fire-and-forget task can be
# garbage-collected before it finishes; each task removes itself when done.
_scheduled_background_tasks = set()


async def scheduled_task():
    """
    Scheduled job: fetch a query from the external API and run it.

    Registered by the startup hook to fire every 10 minutes. Steps:
    1. GET SCHEDULE_QUERY_API (with a Bearer token when SCHEDULE_QUERY_API_KEY
       is set) to obtain a query.
    2. If the JSON object contains a non-empty string "query", start
       execute_task for it in the background.
    3. On any failure, log and skip this run (no fallback query is used).
    """
    logger.info("定时任务触发")
    try:
        query_api = os.getenv("SCHEDULE_QUERY_API")
        if not query_api:
            logger.warning("未配置 SCHEDULE_QUERY_API,跳过定时任务")
            return
        api_key = os.getenv("SCHEDULE_QUERY_API_KEY", "")
        timeout = float(os.getenv("SCHEDULE_QUERY_API_TIMEOUT", "10.0"))
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}

        async with httpx.AsyncClient() as client:
            logger.info(f"调用外部 API: {query_api}")
            response = await client.get(query_api, headers=headers, timeout=timeout)
            response.raise_for_status()
            data = response.json()

        # Validate the payload shape before touching it.
        if not isinstance(data, dict):
            logger.error(f"定时任务失败:外部 API 返回的不是 JSON 对象: {type(data).__name__}")
            return
        query = data.get("query")
        if not query:
            logger.info("定时任务跳过:外部 API 返回的 query 为空")
            return
        if not isinstance(query, str):
            logger.error(f"定时任务失败:外部 API 返回的 query 不是字符串: {type(query).__name__}")
            return

        logger.info(f"定时任务启动: query={query[:50]}...")
        task = asyncio.create_task(execute_task(query, task_type="scheduled"))
        _scheduled_background_tasks.add(task)
        task.add_done_callback(_scheduled_background_tasks.discard)
    # Order matters: httpx.TimeoutException is a subclass of httpx.RequestError,
    # so it must be caught before RequestError or its handler never runs.
    except httpx.HTTPStatusError as e:
        # Truncate the body so a large error page does not flood the log.
        logger.error(f"定时任务失败:外部 API 返回错误 {e.response.status_code}: {e.response.text[:500]}")
    except httpx.TimeoutException:
        logger.error("定时任务失败:外部 API 请求超时")
    except httpx.RequestError as e:
        logger.error(f"定时任务失败:外部 API 请求失败: {e}")
    except Exception as e:
        logger.error(f"定时任务失败:未知错误: {e}", exc_info=True)
- # ============ API 接口 ============
# Strong references to background tasks spawned by create_task. asyncio keeps
# only weak references to running tasks, so without this set a run could be
# garbage-collected after the request handler returns.
_api_background_tasks = set()


def _build_api_run(query: str, trace_dir: str):
    """
    Build the runner, initial messages and run config for one API-triggered run.

    Args:
        query: Query text rendered into the prompt.
        trace_dir: Directory for the FileSystemTraceStore; created if missing.

    Returns:
        A ``(runner, messages, config)`` tuple to pass to ``runner.run``.
    """
    # Lazy imports of the project-local agent package.
    from agent import AgentRunner, RunConfig, FileSystemTraceStore
    from agent.llm import create_openrouter_llm_call
    from agent.llm.prompts import SimplePrompt
    from agent.tools.builtin.knowledge import KnowledgeConfig

    prompt = SimplePrompt(Path(__file__).parent / "content_finder.prompt")
    messages = prompt.build_messages(query=query)

    model_name = prompt.config.get("model", "sonnet-4.6")
    model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
    temperature = float(prompt.config.get("temperature", 0.3))
    max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
    skills_dir = str(Path(__file__).parent / "skills")

    Path(trace_dir).mkdir(parents=True, exist_ok=True)
    store = FileSystemTraceStore(base_path=trace_dir)

    allowed_tools = [
        "douyin_search",
        "douyin_user_videos",
        "get_content_fans_portrait",
        "get_account_fans_portrait",
    ]
    runner = AgentRunner(
        llm_call=create_openrouter_llm_call(model=model),
        trace_store=store,
        skills_dir=skills_dir,
    )
    config = RunConfig(
        name="内容寻找",
        model=model,
        temperature=temperature,
        max_iterations=max_iterations,
        tools=allowed_tools,
        extra_llm_params={"max_tokens": 8192},
        knowledge=KnowledgeConfig(
            enable_extraction=True,
            enable_completion_extraction=True,
            enable_injection=True,
            owner="content_finder_agent",
            default_tags={"project": "content_finder"},
            default_scopes=["com.piaoquantv.supply"],
            default_search_types=["tool", "usecase", "definition"],
            default_search_owner="content_finder_agent",
        ),
    )
    return runner, messages, config


@app.post("/api/tasks", response_model=TaskResponse)
async def create_task(request: TaskRequest):
    """
    Create a content-finding task and return once its trace_id is known.

    The agent runs in a background task bounded by ``task_semaphore``; this
    handler waits up to 5 seconds for the first Trace to report its trace_id.

    Args:
        request.query: Query text (optional; defaults to core.DEFAULT_QUERY).

    Returns:
        {
            "trace_id": "20260317_103046_xyz789",
            "status": "started",
            "query": "...",
            "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
        }
        ``status`` is "failed" (with an "error_<timestamp>" trace_id) when the
        background run raised before producing a trace.

    Raises:
        HTTPException(500): no trace_id within 5 seconds (e.g. every
            concurrency slot is busy). The pending run is cancelled so it does
            not start later behind the client's back.
    """
    query = request.query or core.DEFAULT_QUERY
    # Resolved here so the response message names the directory the trace
    # store actually writes to (honours TRACE_DIR).
    trace_dir = os.getenv("TRACE_DIR", ".cache/traces")

    # The background run publishes the first trace_id (or a startup failure)
    # into the holder and sets the event; the handler waits on it below.
    trace_id_ready = asyncio.Event()
    trace_id_holder = {"id": None, "failed": False}

    async def run_and_capture():
        try:
            from agent import Trace

            async with task_semaphore:
                runner, messages, config = _build_api_run(query, trace_dir)
                async for item in runner.run(messages=messages, config=config):
                    if not isinstance(item, Trace):
                        continue
                    if not trace_id_holder["id"]:
                        trace_id_holder["id"] = item.trace_id
                        trace_id_ready.set()
                        logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
                    if item.status == "completed":
                        stats["completed_tasks"] += 1
                        logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
                        break
                    if item.status == "failed":
                        stats["failed_tasks"] += 1
                        logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
                        break
        except asyncio.CancelledError:
            # Cancelled on start-up timeout (below) or at shutdown: count it so
            # total_tasks stays balanced, then let the cancellation propagate.
            stats["failed_tasks"] += 1
            logger.warning(f"任务取消 [api]: trace_id={trace_id_holder['id']}")
            raise
        except Exception as e:
            stats["failed_tasks"] += 1
            logger.error(f"任务异常 [api]: {e}", exc_info=True)
            if not trace_id_holder["id"]:
                trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                trace_id_holder["failed"] = True
            trace_id_ready.set()

    # Start the background run and keep a strong reference until it finishes.
    stats["total_tasks"] += 1
    task = asyncio.create_task(run_and_capture())
    _api_background_tasks.add(task)
    task.add_done_callback(_api_background_tasks.discard)

    # Wait for the trace_id (at most 5 seconds).
    try:
        await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
    except asyncio.TimeoutError:
        # The client is told the task did not start; cancel it so it does not
        # run later (typically it is still queued on task_semaphore).
        task.cancel()
        logger.error("获取 trace_id 超时")
        raise HTTPException(status_code=500, detail="任务启动超时")

    trace_id = trace_id_holder["id"]
    if trace_id_holder["failed"]:
        # No trace directory exists for a run that crashed before its first
        # Trace, so do not claim that it started.
        return TaskResponse(
            trace_id=trace_id,
            status="failed",
            query=query,
            message="任务启动失败,详情请查看服务日志",
        )
    return TaskResponse(
        trace_id=trace_id,
        status="started",
        query=query,
        message=f"任务已启动,结果将保存到 {trace_dir}/{trace_id}/",
    )
@app.get("/health")
async def health_check():
    """Health check: report concurrency usage, scheduler state and task stats."""
    # Permits in use = configured limit minus the semaphore's free-permit
    # counter (`_value` is a private asyncio attribute).
    busy = MAX_CONCURRENT_TASKS - task_semaphore._value
    return dict(
        status="ok",
        max_concurrent_tasks=MAX_CONCURRENT_TASKS,
        current_tasks=busy,
        scheduler_running=scheduler.running,
        stats=stats,
    )
@app.get("/")
async def root():
    """Root endpoint: service name, version and a map of available endpoints."""
    endpoints = {
        "create_task": "POST /api/tasks",
        "health": "GET /health",
    }
    return {"service": "内容寻找服务", "version": "1.0.0", "endpoints": endpoints}
- # ============ 启动事件 ============
@app.on_event("startup")
async def startup():
    """Log the service configuration and start the scheduler when a query API is configured."""
    divider = "=" * 60
    logger.info(divider)
    logger.info("内容寻找服务启动中...")
    logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")

    # The periodic job needs an external query source; without one it is not scheduled.
    schedule_source = os.getenv("SCHEDULE_QUERY_API")
    if not schedule_source:
        logger.info("未配置 SCHEDULE_QUERY_API,定时任务未启动")
    else:
        # Cron minute="*/10": fire on every minute divisible by 10.
        scheduler.add_job(scheduled_task, "cron", minute="*/10")
        scheduler.start()
        logger.info(f"定时任务已启动:每 10 分钟执行一次")
        logger.info(f"外部 API: {schedule_source}")

    logger.info("服务启动完成")
    logger.info(divider)
@app.on_event("shutdown")
async def shutdown():
    """Stop the scheduler, if it was started, when the service shuts down."""
    logger.info("服务关闭中...")
    scheduler_active = scheduler.running
    if scheduler_active:
        scheduler.shutdown()
    logger.info("服务已关闭")
- # ============ 主函数 ============
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on HOST:PORT.
    import uvicorn

    listen_host = os.getenv("HOST", "0.0.0.0")
    listen_port = int(os.getenv("PORT", "8080"))
    logger.info(f"启动服务: http://{listen_host}:{listen_port}")
    uvicorn.run(app, host=listen_host, port=listen_port)
|