# server.py
  1. """
  2. 内容寻找服务
  3. 提供:
  4. 1. API 接口:POST /api/tasks - 触发内容寻找任务
  5. 2. 定时调度:启动后先恢复 demand_find_task 中 status=执行中 的任务;之后每 2 分钟轮询一次,
  6. 若当前无任务在执行,则从 demand_content 取当天(dt=YYYYMMDD)、未建任务记录且 score 最高的一条执行(不区分品类)
  7. 3. 并发控制:限制最大并发任务数;定时侧若已有任务在执行则跳过本次轮询
  8. 4. 单次寻找任务最长执行 25 分钟,超时记为失败并回写 demand_find_task
  9. """
  10. import asyncio
  11. import logging
  12. import os
  13. import uuid
  14. from datetime import datetime
  15. from pathlib import Path
  16. from typing import Optional
  17. import sys
  18. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  19. from fastapi import FastAPI, HTTPException
  20. from pydantic import BaseModel
  21. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  22. from zoneinfo import ZoneInfo
  23. from dotenv import load_dotenv
  24. load_dotenv()
  25. import core
  26. from db import (
  27. create_task_record,
  28. get_first_running_task,
  29. get_one_today_unprocessed_demand,
  30. update_task_status,
  31. update_task_on_complete,
  32. )
  33. from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED
# Logging: write to .cache/server.log next to this file AND to stderr.
log_dir = Path(__file__).parent / '.cache'
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="内容寻找服务",
    version="1.0.0",
    description="抖音内容寻找 Agent 服务"
)

# Scheduler (defaults to China timezone so a UTC container does not miss the
# intended local fire times).
SCHEDULER_TIMEZONE = os.getenv("SCHEDULER_TIMEZONE", os.getenv("TZ", "Asia/Shanghai"))
SCHEDULER_TZ = ZoneInfo(SCHEDULER_TIMEZONE)
scheduler = AsyncIOScheduler(timezone=SCHEDULER_TZ)

# Concurrency control: shared semaphore capping simultaneous agent runs.
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "1"))
task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

# Scheduling: dispatch interval (seconds) and per-task timeout
# (seconds; default 1500s = 25 minutes).
# - To avoid dispatching several tasks simultaneously at startup (risking
#   duplicate processing), only one task is dispatched per tick (default 30s);
#   repeated ticks gradually fill the slots up to MAX_CONCURRENT_TASKS.
SCHEDULE_DISPATCH_INTERVAL_SECONDS = int(os.getenv("SCHEDULE_DISPATCH_INTERVAL_SECONDS", "30"))
TASK_TIMEOUT_SECONDS = int(os.getenv("SCHEDULE_TASK_TIMEOUT_SECONDS", "1500"))

# In-process counters, exposed via GET /health.
stats = {
    "total_tasks": 0,
    "completed_tasks": 0,
    "failed_tasks": 0,
    "scheduled_tasks": 0
}
  71. # ============ 数据模型 ============
class TaskRequest(BaseModel):
    """Request body for POST /api/tasks; all fields are optional."""
    query: Optional[str] = None       # search query; falls back to core.DEFAULT_QUERY
    demand_id: Optional[int] = None   # demand_content.id to associate with the run
    suggestion: Optional[str] = None  # supplementary hint passed to the agent
class TaskResponse(BaseModel):
    """Response body for POST /api/tasks."""
    trace_id: str  # identifier for locating results under .cache/traces/<trace_id>/
    status: str    # "started" (the task runs in the background)
    query: str     # the query actually used
    message: str   # human-readable hint about where results will be stored
  81. # ============ 核心函数 ============
  82. def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
  83. """定时任务完成时更新 trace_id 和 status,静默处理异常"""
  84. try:
  85. update_task_on_complete(demand_id, trace_id, status)
  86. except Exception as e:
  87. logger.warning(f"更新任务状态失败: {e}")
async def execute_task(
    query: str,
    demand_id: Optional[int] = None,
    suggestion: str = "",
    task_type: str = "api",
):
    """
    Execute one content-finding run under the shared concurrency semaphore.

    Args:
        query: query text passed to the agent.
        demand_id: demand id (demand_content.id, linking back to demand_content).
        suggestion: supplementary info (for scheduled tasks this matches
            demand_content.suggestion).
        task_type: "api" or "scheduled"; only scheduled tasks write status
            back to demand_find_task.
    """
    async with task_semaphore:
        # NOTE(review): after acquiring the semaphore, MAX_CONCURRENT_TASKS -
        # _value already counts this task, so the "+ 1" appears to over-report
        # concurrency by one. Log-only value — confirm intent.
        current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1
        logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
        start_time = datetime.now(SCHEDULER_TZ)
        stats["total_tasks"] += 1
        if task_type == "scheduled":
            stats["scheduled_tasks"] += 1
        # Scheduled tasks: mark the demand_find_task row as running before starting.
        # Best-effort — a DB failure here does not block execution.
        if task_type == "scheduled" and demand_id is not None:
            try:
                update_task_status("", demand_id, STATUS_RUNNING)
            except Exception as e:
                logger.warning(f"更新任务状态为执行中失败: {e}")
        try:
            # Hard per-task timeout; on expiry the run is recorded as failed.
            result = await asyncio.wait_for(
                core.run_agent(
                    query,
                    demand_id=demand_id,
                    suggestion=suggestion or None,
                    stream_output=False,
                    log_assistant_text=True,
                ),
                timeout=float(TASK_TIMEOUT_SECONDS),
            )
            duration = (datetime.now(SCHEDULER_TZ) - start_time).total_seconds()
            if result["status"] == "completed":
                stats["completed_tasks"] += 1
                logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
                if task_type == "scheduled" and demand_id is not None:
                    _update_scheduled_task_complete(demand_id, result["trace_id"], STATUS_SUCCESS)
            else:
                stats["failed_tasks"] += 1
                logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
                if task_type == "scheduled" and demand_id is not None:
                    _update_scheduled_task_complete(demand_id, result.get("trace_id") or "", STATUS_FAILED)
        except asyncio.TimeoutError:
            # Timeout: count as failed; no trace_id is available to persist.
            stats["failed_tasks"] += 1
            duration = (datetime.now(SCHEDULER_TZ) - start_time).total_seconds()
            logger.error(
                f"任务超时 [{task_type}]: 超过 {TASK_TIMEOUT_SECONDS}s,记为失败, 耗时={duration:.1f}s"
            )
            if task_type == "scheduled" and demand_id is not None:
                _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
        except Exception as e:
            # Any other error: count as failed and persist the failure for
            # scheduled tasks so the row does not stay stuck in "running".
            stats["failed_tasks"] += 1
            duration = (datetime.now(SCHEDULER_TZ) - start_time).total_seconds()
            logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
            if task_type == "scheduled" and demand_id is not None:
                _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  150. def _today_dt_int() -> int:
  151. """当天 demand_content.dt 约定为 YYYYMMDD 整数(如 20260402),与定时器时区一致。"""
  152. return int(datetime.now(SCHEDULER_TZ).strftime("%Y%m%d"))
  153. def _has_running_content_task() -> bool:
  154. """
  155. 本进程内是否有内容寻找任务正在执行(占用并发槽)。
  156. 与 execute_task 共用 task_semaphore,含 API 触发与定时触发。
  157. """
  158. return task_semaphore._value != MAX_CONCURRENT_TASKS
  159. async def scheduled_tick():
  160. """
  161. 按 SCHEDULE_DISPATCH_INTERVAL_SECONDS 派发:若当前并发有空槽,则从 demand_content 取
  162. 当天(dt=今日)、尚未出现在 demand_find_task 中且 score 最高的一条需求并执行。
  163. """
  164. logger.info("定时任务触发(scheduled_tick)")
  165. # 无空闲并发槽则不派发;保持 tick 很快返回,避免阻塞调度器。
  166. if task_semaphore._value <= 0:
  167. logger.info("定时任务跳过:无空闲并发槽")
  168. return
  169. dt = _today_dt_int()
  170. item = get_one_today_unprocessed_demand(dt=dt)
  171. if not item:
  172. logger.info(f"定时任务跳过:无待处理需求(dt={dt} 或均已建任务)")
  173. return
  174. demand_content_id = item.get("demand_content_id")
  175. query = (item.get("query") or "").strip()
  176. suggestion = (item.get("suggestion") or "").strip()
  177. if demand_content_id is None or not query:
  178. logger.info("定时任务跳过:查询结果无效")
  179. return
  180. score = item.get("score")
  181. logger.info(
  182. f"定时任务领取(当天 score 最高):demand_content_id={demand_content_id}, "
  183. f"dt={dt}, score={score}"
  184. )
  185. create_task_record(demand_content_id)
  186. # 后台执行:由 execute_task 内部 semaphore 控制并发占用
  187. asyncio.create_task(
  188. execute_task(
  189. query=query,
  190. demand_id=demand_content_id,
  191. suggestion=suggestion,
  192. task_type="scheduled",
  193. )
  194. )
  195. async def run_startup_resume():
  196. """
  197. 启动后先执行 demand_find_task 中 status=执行中(1) 的任务(理论上仅一条)。
  198. """
  199. try:
  200. row = get_first_running_task()
  201. if not row:
  202. logger.info("启动恢复:无执行中(status=1)的 demand_find_task")
  203. return
  204. demand_content_id = row.get("demand_content_id")
  205. query = (row.get("query") or "").strip()
  206. suggestion = (row.get("suggestion") or "").strip()
  207. if demand_content_id is None or not query:
  208. logger.warning("启动恢复:执行中任务数据不完整,跳过")
  209. return
  210. logger.info(f"启动恢复:执行 demand_find_task status=1, demand_content_id={demand_content_id}")
  211. await execute_task(
  212. query=query,
  213. demand_id=int(demand_content_id),
  214. suggestion=suggestion,
  215. task_type="scheduled",
  216. )
  217. except Exception as e:
  218. logger.error(f"启动恢复失败: {e}", exc_info=True)
  219. # ============ API 接口 ============
  220. @app.post("/api/tasks", response_model=TaskResponse)
  221. async def create_task(request: TaskRequest):
  222. """
  223. 创建内容寻找任务
  224. Args:
  225. request.query: 查询内容(可选,不传则使用默认值)
  226. Returns:
  227. {
  228. "trace_id": "20260317_103046_xyz789",
  229. "status": "started",
  230. "query": "...",
  231. "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
  232. }
  233. """
  234. # 获取 query、demand_id、suggestion(API 显式传入;与库表字段同名便于对齐)
  235. query = request.query or core.DEFAULT_QUERY
  236. demand_id = request.demand_id
  237. suggestion_str = (request.suggestion or "").strip()
  238. # 用 Event 等待 trace_id
  239. trace_id_ready = asyncio.Event()
  240. trace_id_holder = {"id": None}
  241. async def run_and_capture():
  242. try:
  243. # 获取第一个 Trace 对象来获取 trace_id
  244. from agent import Trace
  245. async with task_semaphore:
  246. # 重新构建 runner 来获取 trace_id
  247. from agent import AgentRunner, RunConfig, FileSystemTraceStore
  248. from agent.llm import create_openrouter_llm_call
  249. from agent.llm.prompts import SimplePrompt
  250. from agent.tools.builtin.knowledge import KnowledgeConfig
  251. prompt_path = Path(__file__).parent / "content_finder.md"
  252. prompt = SimplePrompt(prompt_path)
  253. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  254. demand_id_str = str(demand_id) if demand_id is not None else ""
  255. messages = prompt.build_messages(
  256. query=query,
  257. suggestion=suggestion_str,
  258. trace_dir=trace_dir,
  259. demand_id=demand_id_str,
  260. )
  261. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  262. model_name = prompt.config.get("model", "sonnet-4.6")
  263. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  264. temperature = float(prompt.config.get("temperature", 0.3))
  265. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  266. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  267. skills_dir = str(Path(__file__).parent / "skills")
  268. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  269. store = FileSystemTraceStore(base_path=trace_dir)
  270. allowed_tools = [
  271. "douyin_search",
  272. "douyin_search_tikhub",
  273. "douyin_user_videos",
  274. "get_content_fans_portrait",
  275. "get_account_fans_portrait",
  276. "batch_fetch_portraits",
  277. "store_results_mysql",
  278. "exec_summary",
  279. ]
  280. runner = AgentRunner(
  281. llm_call=create_openrouter_llm_call(model=model),
  282. trace_store=store,
  283. skills_dir=skills_dir,
  284. )
  285. config = RunConfig(
  286. name="内容寻找",
  287. model=model,
  288. temperature=temperature,
  289. max_iterations=max_iterations,
  290. tools=allowed_tools,
  291. extra_llm_params={"max_tokens": 8192},
  292. knowledge=KnowledgeConfig(
  293. enable_extraction=True,
  294. enable_completion_extraction=True,
  295. enable_injection=True,
  296. owner="content_finder_agent",
  297. default_tags={"project": "content_finder"},
  298. default_scopes=["com.piaoquantv.supply"],
  299. default_search_types=["tool", "usecase", "definition"],
  300. default_search_owner="content_finder_agent"
  301. )
  302. )
  303. async for item in runner.run(messages=messages, config=config):
  304. if isinstance(item, Trace):
  305. if not trace_id_holder["id"]:
  306. trace_id_holder["id"] = item.trace_id
  307. trace_id_ready.set()
  308. logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
  309. if item.status == "completed":
  310. stats["completed_tasks"] += 1
  311. logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
  312. break
  313. elif item.status == "failed":
  314. stats["failed_tasks"] += 1
  315. logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
  316. break
  317. except Exception as e:
  318. stats["failed_tasks"] += 1
  319. logger.error(f"任务异常 [api]: {e}", exc_info=True)
  320. if not trace_id_holder["id"]:
  321. trace_id_holder["id"] = f"error_{datetime.now(SCHEDULER_TZ).strftime('%Y%m%d_%H%M%S')}"
  322. trace_id_ready.set()
  323. # 启动后台任务
  324. stats["total_tasks"] += 1
  325. asyncio.create_task(run_and_capture())
  326. # 等待 trace_id(最多 5 秒)
  327. try:
  328. await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
  329. except asyncio.TimeoutError:
  330. logger.error("获取 trace_id 超时")
  331. raise HTTPException(status_code=500, detail="任务启动超时")
  332. trace_id = trace_id_holder["id"]
  333. return TaskResponse(
  334. trace_id=trace_id,
  335. status="started",
  336. query=query,
  337. message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
  338. )
  339. @app.get("/health")
  340. async def health_check():
  341. """健康检查"""
  342. return {
  343. "status": "ok",
  344. "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
  345. "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
  346. "scheduler_running": scheduler.running,
  347. "stats": stats
  348. }
  349. @app.get("/")
  350. async def root():
  351. """根路径"""
  352. return {
  353. "service": "内容寻找服务",
  354. "version": "1.0.0",
  355. "endpoints": {
  356. "create_task": "POST /api/tasks",
  357. "health": "GET /health"
  358. }
  359. }
  360. # ============ 启动事件 ============
  361. @app.on_event("startup")
  362. async def startup():
  363. """服务启动时初始化"""
  364. logger.info("=" * 60)
  365. logger.info("内容寻找服务启动中...")
  366. logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")
  367. logger.info(f"定时器时区: {SCHEDULER_TIMEZONE}")
  368. logger.info(
  369. f"定时策略:每 {SCHEDULE_DISPATCH_INTERVAL_SECONDS} 秒尝试派发 1 条(有并发空槽才派发);"
  370. f"单次任务超时 {TASK_TIMEOUT_SECONDS}s"
  371. )
  372. asyncio.create_task(run_startup_resume())
  373. job = scheduler.add_job(
  374. scheduled_tick,
  375. "interval",
  376. seconds=SCHEDULE_DISPATCH_INTERVAL_SECONDS,
  377. misfire_grace_time=300,
  378. coalesce=True,
  379. max_instances=1,
  380. )
  381. scheduler.start()
  382. logger.info(f"定时任务已注册: id={job.id}, next_run_time={job.next_run_time}")
  383. logger.info("服务启动完成")
  384. logger.info("=" * 60)
  385. @app.on_event("shutdown")
  386. async def shutdown():
  387. """服务关闭时清理"""
  388. logger.info("服务关闭中...")
  389. if scheduler.running:
  390. scheduler.shutdown()
  391. logger.info("服务已关闭")
  392. # ============ 主函数 ============
  393. if __name__ == "__main__":
  394. import uvicorn
  395. port = int(os.getenv("PORT", "8080"))
  396. host = os.getenv("HOST", "0.0.0.0")
  397. logger.info(f"启动服务: http://{host}:{port}")
  398. uvicorn.run(app, host=host, port=port)