server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
"""
Content-finding service.

Provides:
1. API endpoint: POST /api/tasks - trigger a content-finding task
2. Scheduled job: every 10 minutes, join-query the database for
   unprocessed demands and execute tasks for them
3. Concurrency control: caps the number of concurrently running tasks
"""
import asyncio
import logging
import os
import uuid  # NOTE(review): appears unused in this file — confirm before removing
from datetime import datetime
from pathlib import Path
from typing import Optional
import sys

# Make the project root importable when this file is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from dotenv import load_dotenv

# Load .env before importing project modules that read environment variables.
load_dotenv()

import core
from db import get_next_unprocessed_demand, create_task_record, update_task_status, update_task_on_complete
from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED

# Logging: write to a file under .cache/ and to the console.
log_dir = Path(__file__).parent / '.cache'
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="内容寻找服务",
    version="1.0.0",
    description="抖音内容寻找 Agent 服务"
)

# Scheduler driving the periodic demand scan
scheduler = AsyncIOScheduler()

# Concurrency control: at most MAX_CONCURRENT_TASKS agent runs at once
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "3"))
task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

# In-process counters, exposed via GET /health
stats = {
    "total_tasks": 0,
    "completed_tasks": 0,
    "failed_tasks": 0,
    "scheduled_tasks": 0
}
# ============ Data models ============
class TaskRequest(BaseModel):
    """Request body for POST /api/tasks."""
    # Free-text query; when omitted, core.DEFAULT_QUERY is used.
    query: Optional[str] = None
    # Optional demand_content.id linking the run to a demand row.
    demand_id: Optional[int] = None
class TaskResponse(BaseModel):
    """Response body for POST /api/tasks."""
    # Run identifier; also the name of the trace output directory.
    trace_id: str
    # Always "started" when the request succeeds.
    status: str
    # The query actually used for the run.
    query: str
    # Human-readable hint about where results are written.
    message: str
  64. # ============ 核心函数 ============
  65. def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
  66. """定时任务完成时更新 trace_id 和 status,静默处理异常"""
  67. try:
  68. update_task_on_complete(demand_id, trace_id, status)
  69. except Exception as e:
  70. logger.warning(f"更新任务状态失败: {e}")
  71. async def execute_task(
  72. query: str,
  73. demand_id: Optional[int] = None,
  74. task_type: str = "api",
  75. ):
  76. """
  77. 执行任务(带并发控制)
  78. Args:
  79. query: 查询内容
  80. demand_id: 需求 id(demand_content.id,关联 demand_content 表)
  81. task_type: 任务类型("api" 或 "scheduled")
  82. """
  83. async with task_semaphore:
  84. current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1
  85. logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
  86. start_time = datetime.now()
  87. stats["total_tasks"] += 1
  88. if task_type == "scheduled":
  89. stats["scheduled_tasks"] += 1
  90. if task_type == "scheduled" and demand_id is not None:
  91. try:
  92. update_task_status("", demand_id, STATUS_RUNNING)
  93. except Exception as e:
  94. logger.warning(f"更新任务状态为执行中失败: {e}")
  95. try:
  96. result = await core.run_agent(
  97. query, demand_id=demand_id, stream_output=False
  98. )
  99. duration = (datetime.now() - start_time).total_seconds()
  100. if result["status"] == "completed":
  101. stats["completed_tasks"] += 1
  102. logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
  103. if task_type == "scheduled" and demand_id is not None:
  104. _update_scheduled_task_complete(demand_id, result["trace_id"], STATUS_SUCCESS)
  105. else:
  106. stats["failed_tasks"] += 1
  107. logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
  108. if task_type == "scheduled" and demand_id is not None:
  109. _update_scheduled_task_complete(demand_id, result.get("trace_id") or "", STATUS_FAILED)
  110. except Exception as e:
  111. stats["failed_tasks"] += 1
  112. duration = (datetime.now() - start_time).total_seconds()
  113. logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
  114. if task_type == "scheduled" and demand_id is not None:
  115. _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  116. async def scheduled_task():
  117. """
  118. 定时任务:每 10 分钟执行一次
  119. 流程:
  120. 1. 联表查询 demand_content + demand_find_task,获取创建时间最早的未处理的 demand_content
  121. 2. 在 demand_find_task 新增记录
  122. 3. 调用 execute_task 执行
  123. """
  124. logger.info("定时任务触发")
  125. demand = get_next_unprocessed_demand()
  126. if not demand:
  127. logger.info("定时任务跳过:无待处理需求")
  128. return
  129. query = demand.get("query") or ""
  130. if not query:
  131. logger.info("定时任务跳过:该需求的 query 为空")
  132. return
  133. demand_content_id = demand.get("demand_content_id")
  134. if demand_content_id is None:
  135. logger.warning("定时任务跳过:demand_content_id 为空")
  136. return
  137. create_task_record(demand_content_id) # trace_id 初始为空,完成后更新
  138. asyncio.create_task(execute_task(query=query, demand_id=demand_content_id, task_type="scheduled"))
  139. # ============ API 接口 ============
  140. @app.post("/api/tasks", response_model=TaskResponse)
  141. async def create_task(request: TaskRequest):
  142. """
  143. 创建内容寻找任务
  144. Args:
  145. request.query: 查询内容(可选,不传则使用默认值)
  146. Returns:
  147. {
  148. "trace_id": "20260317_103046_xyz789",
  149. "status": "started",
  150. "query": "...",
  151. "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
  152. }
  153. """
  154. # 获取 query 和 demand_id
  155. query = request.query or core.DEFAULT_QUERY
  156. demand_id = request.demand_id
  157. # 用 Event 等待 trace_id
  158. trace_id_ready = asyncio.Event()
  159. trace_id_holder = {"id": None}
  160. async def run_and_capture():
  161. try:
  162. # 获取第一个 Trace 对象来获取 trace_id
  163. from agent import Trace
  164. async with task_semaphore:
  165. # 重新构建 runner 来获取 trace_id
  166. from agent import AgentRunner, RunConfig, FileSystemTraceStore
  167. from agent.llm import create_openrouter_llm_call
  168. from agent.llm.prompts import SimplePrompt
  169. from agent.tools.builtin.knowledge import KnowledgeConfig
  170. prompt_path = Path(__file__).parent / "content_finder.prompt"
  171. prompt = SimplePrompt(prompt_path)
  172. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  173. demand_id_str = str(demand_id) if demand_id is not None else ""
  174. messages = prompt.build_messages(query=query, trace_dir=trace_dir, demand_id=demand_id_str)
  175. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  176. model_name = prompt.config.get("model", "sonnet-4.6")
  177. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  178. temperature = float(prompt.config.get("temperature", 0.3))
  179. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  180. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  181. skills_dir = str(Path(__file__).parent / "skills")
  182. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  183. store = FileSystemTraceStore(base_path=trace_dir)
  184. allowed_tools = [
  185. "douyin_search",
  186. "douyin_user_videos",
  187. "get_content_fans_portrait",
  188. "get_account_fans_portrait",
  189. "store_results_mysql",
  190. ]
  191. runner = AgentRunner(
  192. llm_call=create_openrouter_llm_call(model=model),
  193. trace_store=store,
  194. skills_dir=skills_dir,
  195. )
  196. config = RunConfig(
  197. name="内容寻找",
  198. model=model,
  199. temperature=temperature,
  200. max_iterations=max_iterations,
  201. tools=allowed_tools,
  202. extra_llm_params={"max_tokens": 8192},
  203. knowledge=KnowledgeConfig(
  204. enable_extraction=True,
  205. enable_completion_extraction=True,
  206. enable_injection=True,
  207. owner="content_finder_agent",
  208. default_tags={"project": "content_finder"},
  209. default_scopes=["com.piaoquantv.supply"],
  210. default_search_types=["tool", "usecase", "definition"],
  211. default_search_owner="content_finder_agent"
  212. )
  213. )
  214. async for item in runner.run(messages=messages, config=config):
  215. if isinstance(item, Trace):
  216. if not trace_id_holder["id"]:
  217. trace_id_holder["id"] = item.trace_id
  218. trace_id_ready.set()
  219. logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
  220. if item.status == "completed":
  221. stats["completed_tasks"] += 1
  222. logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
  223. break
  224. elif item.status == "failed":
  225. stats["failed_tasks"] += 1
  226. logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
  227. break
  228. except Exception as e:
  229. stats["failed_tasks"] += 1
  230. logger.error(f"任务异常 [api]: {e}", exc_info=True)
  231. if not trace_id_holder["id"]:
  232. trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
  233. trace_id_ready.set()
  234. # 启动后台任务
  235. stats["total_tasks"] += 1
  236. asyncio.create_task(run_and_capture())
  237. # 等待 trace_id(最多 5 秒)
  238. try:
  239. await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
  240. except asyncio.TimeoutError:
  241. logger.error("获取 trace_id 超时")
  242. raise HTTPException(status_code=500, detail="任务启动超时")
  243. trace_id = trace_id_holder["id"]
  244. return TaskResponse(
  245. trace_id=trace_id,
  246. status="started",
  247. query=query,
  248. message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
  249. )
  250. @app.get("/health")
  251. async def health_check():
  252. """健康检查"""
  253. return {
  254. "status": "ok",
  255. "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
  256. "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
  257. "scheduler_running": scheduler.running,
  258. "stats": stats
  259. }
  260. @app.get("/")
  261. async def root():
  262. """根路径"""
  263. return {
  264. "service": "内容寻找服务",
  265. "version": "1.0.0",
  266. "endpoints": {
  267. "create_task": "POST /api/tasks",
  268. "health": "GET /health"
  269. }
  270. }
  271. # ============ 启动事件 ============
  272. @app.on_event("startup")
  273. async def startup():
  274. """服务启动时初始化"""
  275. logger.info("=" * 60)
  276. logger.info("内容寻找服务启动中...")
  277. logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")
  278. # 配置定时任务(从 demand_content 联表查询未处理需求,无需外部 API)
  279. scheduler.add_job(scheduled_task, "cron", minute="*/10")
  280. scheduler.start()
  281. asyncio.create_task(scheduled_task())
  282. logger.info("定时任务已启动:启动后立即执行一次,之后每 10 分钟执行(从数据库获取待处理需求)")
  283. logger.info("服务启动完成")
  284. logger.info("=" * 60)
@app.on_event("shutdown")
async def shutdown():
    """Stop the scheduler cleanly when the service shuts down."""
    logger.info("服务关闭中...")
    # Guard against shutting down a scheduler that never started.
    if scheduler.running:
        scheduler.shutdown()
    logger.info("服务已关闭")
  292. # ============ 主函数 ============
  293. if __name__ == "__main__":
  294. import uvicorn
  295. port = int(os.getenv("PORT", "8080"))
  296. host = os.getenv("HOST", "0.0.0.0")
  297. logger.info(f"启动服务: http://{host}:{port}")
  298. uvicorn.run(app, host=host, port=port)