server.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. """
  2. 内容寻找服务
  3. 提供:
  4. 1. API 接口:POST /api/tasks - 触发内容寻找任务
  5. 2. 定时调度:启动后先恢复 demand_find_task 中 status=执行中 的任务;之后每 10 分钟从
  6. demand_content 取当天(dt=YYYYMMDD)且未建任务记录的 1 条需求执行(不区分品类)
  7. 3. 并发控制:限制最大并发任务数;定时侧若已有任务在执行则跳过本次轮询
  8. 4. 单次寻找任务最长执行 15 分钟,超时记为失败并回写 demand_find_task
  9. """
  10. import asyncio
  11. import logging
  12. import os
  13. import uuid
  14. from datetime import datetime
  15. from pathlib import Path
  16. from typing import Optional
  17. import sys
  18. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  19. from fastapi import FastAPI, HTTPException
  20. from pydantic import BaseModel
  21. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  22. from zoneinfo import ZoneInfo
  23. from dotenv import load_dotenv
  24. load_dotenv()
  25. import core
  26. from db import (
  27. create_task_record,
  28. get_first_running_task,
  29. get_one_today_unprocessed_demand,
  30. update_task_status,
  31. update_task_on_complete,
  32. )
  33. from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED
# Logging: file + console; log file lives under ./.cache next to this module.
log_dir = Path(__file__).parent / '.cache'
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="内容寻找服务",
    version="1.0.0",
    description="抖音内容寻找 Agent 服务"
)

# Scheduler (defaults to China time so that a UTC container does not miss
# the intended local trigger times).
SCHEDULER_TIMEZONE = os.getenv("SCHEDULER_TIMEZONE", os.getenv("TZ", "Asia/Shanghai"))
scheduler = AsyncIOScheduler(timezone=ZoneInfo(SCHEDULER_TIMEZONE))

# Concurrency control: at most MAX_CONCURRENT_TASKS agent runs at a time.
MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "1"))
task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)

# Scheduling: poll interval (minutes) and per-task timeout (seconds, default 15 min).
SCHEDULE_INTERVAL_MINUTES = int(os.getenv("SCHEDULE_INTERVAL_MINUTES", "10"))
TASK_TIMEOUT_SECONDS = int(os.getenv("SCHEDULE_TASK_TIMEOUT_SECONDS", "900"))

# In-process task counters, exposed via /health.
stats = {
    "total_tasks": 0,
    "completed_tasks": 0,
    "failed_tasks": 0,
    "scheduled_tasks": 0
}
  68. # ============ 数据模型 ============
class TaskRequest(BaseModel):
    """Request body for POST /api/tasks."""
    # Query text for the agent; when omitted, core.DEFAULT_QUERY is used.
    query: Optional[str] = None
    # Optional demand_content.id to associate with the run.
    demand_id: Optional[int] = None
class TaskResponse(BaseModel):
    """Response body for POST /api/tasks."""
    # Trace id of the started background run.
    trace_id: str
    # Always "started" (the handler returns before the run finishes).
    status: str
    # The query actually used (request value or the default).
    query: str
    # Human-readable hint pointing at the trace output directory.
    message: str
  77. # ============ 核心函数 ============
  78. def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
  79. """定时任务完成时更新 trace_id 和 status,静默处理异常"""
  80. try:
  81. update_task_on_complete(demand_id, trace_id, status)
  82. except Exception as e:
  83. logger.warning(f"更新任务状态失败: {e}")
  84. async def execute_task(
  85. query: str,
  86. demand_id: Optional[int] = None,
  87. task_type: str = "api",
  88. ):
  89. """
  90. 执行任务(带并发控制)
  91. Args:
  92. query: 查询内容
  93. demand_id: 需求 id(demand_content.id,关联 demand_content 表)
  94. task_type: 任务类型("api" 或 "scheduled")
  95. """
  96. async with task_semaphore:
  97. current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1
  98. logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
  99. start_time = datetime.now()
  100. stats["total_tasks"] += 1
  101. if task_type == "scheduled":
  102. stats["scheduled_tasks"] += 1
  103. if task_type == "scheduled" and demand_id is not None:
  104. try:
  105. update_task_status("", demand_id, STATUS_RUNNING)
  106. except Exception as e:
  107. logger.warning(f"更新任务状态为执行中失败: {e}")
  108. try:
  109. result = await asyncio.wait_for(
  110. core.run_agent(
  111. query, demand_id=demand_id, stream_output=False, log_assistant_text=True
  112. ),
  113. timeout=float(TASK_TIMEOUT_SECONDS),
  114. )
  115. duration = (datetime.now() - start_time).total_seconds()
  116. if result["status"] == "completed":
  117. stats["completed_tasks"] += 1
  118. logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
  119. if task_type == "scheduled" and demand_id is not None:
  120. _update_scheduled_task_complete(demand_id, result["trace_id"], STATUS_SUCCESS)
  121. else:
  122. stats["failed_tasks"] += 1
  123. logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
  124. if task_type == "scheduled" and demand_id is not None:
  125. _update_scheduled_task_complete(demand_id, result.get("trace_id") or "", STATUS_FAILED)
  126. except asyncio.TimeoutError:
  127. stats["failed_tasks"] += 1
  128. duration = (datetime.now() - start_time).total_seconds()
  129. logger.error(
  130. f"任务超时 [{task_type}]: 超过 {TASK_TIMEOUT_SECONDS}s,记为失败, 耗时={duration:.1f}s"
  131. )
  132. if task_type == "scheduled" and demand_id is not None:
  133. _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  134. except Exception as e:
  135. stats["failed_tasks"] += 1
  136. duration = (datetime.now() - start_time).total_seconds()
  137. logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
  138. if task_type == "scheduled" and demand_id is not None:
  139. _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  140. def _today_dt_int() -> int:
  141. """当天 demand_content.dt 约定为 YYYYMMDD 整数(如 20260402),与定时器时区一致。"""
  142. return int(datetime.now(ZoneInfo(SCHEDULER_TIMEZONE)).strftime("%Y%m%d"))
  143. async def scheduled_tick():
  144. """
  145. 每 10 分钟执行一次:若当前无任务占用并发槽,则从 demand_content 取当天(dt=今日)
  146. 且尚未出现在 demand_find_task 中的 1 条需求并执行。
  147. """
  148. logger.info("定时任务触发(scheduled_tick)")
  149. if task_semaphore._value != MAX_CONCURRENT_TASKS:
  150. logger.info("定时任务跳过:仍有任务在执行(并发槽已满)")
  151. return
  152. dt = _today_dt_int()
  153. item = get_one_today_unprocessed_demand(dt=dt)
  154. if not item:
  155. logger.info(f"定时任务跳过:无待处理需求(dt={dt} 或均已建任务)")
  156. return
  157. demand_content_id = item.get("demand_content_id")
  158. query = (item.get("query") or "").strip()
  159. if demand_content_id is None or not query:
  160. logger.info("定时任务跳过:查询结果无效")
  161. return
  162. logger.info(f"定时任务领取:demand_content_id={demand_content_id}, dt={dt}")
  163. create_task_record(demand_content_id)
  164. await execute_task(query=query, demand_id=demand_content_id, task_type="scheduled")
  165. async def run_startup_resume():
  166. """
  167. 启动后先执行 demand_find_task 中 status=执行中(1) 的任务(理论上仅一条)。
  168. """
  169. try:
  170. row = get_first_running_task()
  171. if not row:
  172. logger.info("启动恢复:无执行中(status=1)的 demand_find_task")
  173. return
  174. demand_content_id = row.get("demand_content_id")
  175. query = (row.get("query") or "").strip()
  176. if demand_content_id is None or not query:
  177. logger.warning("启动恢复:执行中任务数据不完整,跳过")
  178. return
  179. logger.info(f"启动恢复:执行 demand_find_task status=1, demand_content_id={demand_content_id}")
  180. await execute_task(query=query, demand_id=int(demand_content_id), task_type="scheduled")
  181. except Exception as e:
  182. logger.error(f"启动恢复失败: {e}", exc_info=True)
  183. # ============ API 接口 ============
  184. @app.post("/api/tasks", response_model=TaskResponse)
  185. async def create_task(request: TaskRequest):
  186. """
  187. 创建内容寻找任务
  188. Args:
  189. request.query: 查询内容(可选,不传则使用默认值)
  190. Returns:
  191. {
  192. "trace_id": "20260317_103046_xyz789",
  193. "status": "started",
  194. "query": "...",
  195. "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
  196. }
  197. """
  198. # 获取 query 和 demand_id
  199. query = request.query or core.DEFAULT_QUERY
  200. demand_id = request.demand_id
  201. # 用 Event 等待 trace_id
  202. trace_id_ready = asyncio.Event()
  203. trace_id_holder = {"id": None}
  204. async def run_and_capture():
  205. try:
  206. # 获取第一个 Trace 对象来获取 trace_id
  207. from agent import Trace
  208. async with task_semaphore:
  209. # 重新构建 runner 来获取 trace_id
  210. from agent import AgentRunner, RunConfig, FileSystemTraceStore
  211. from agent.llm import create_openrouter_llm_call
  212. from agent.llm.prompts import SimplePrompt
  213. from agent.tools.builtin.knowledge import KnowledgeConfig
  214. prompt_path = Path(__file__).parent / "content_finder.md"
  215. prompt = SimplePrompt(prompt_path)
  216. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  217. demand_id_str = str(demand_id) if demand_id is not None else ""
  218. messages = prompt.build_messages(query=query, trace_dir=trace_dir, demand_id=demand_id_str)
  219. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  220. model_name = prompt.config.get("model", "sonnet-4.6")
  221. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  222. temperature = float(prompt.config.get("temperature", 0.3))
  223. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  224. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  225. skills_dir = str(Path(__file__).parent / "skills")
  226. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  227. store = FileSystemTraceStore(base_path=trace_dir)
  228. allowed_tools = [
  229. "douyin_search",
  230. "douyin_search_tikhub",
  231. "douyin_user_videos",
  232. "get_content_fans_portrait",
  233. "get_account_fans_portrait",
  234. "batch_fetch_portraits",
  235. "store_results_mysql",
  236. "exec_summary",
  237. ]
  238. runner = AgentRunner(
  239. llm_call=create_openrouter_llm_call(model=model),
  240. trace_store=store,
  241. skills_dir=skills_dir,
  242. )
  243. config = RunConfig(
  244. name="内容寻找",
  245. model=model,
  246. temperature=temperature,
  247. max_iterations=max_iterations,
  248. tools=allowed_tools,
  249. extra_llm_params={"max_tokens": 8192},
  250. knowledge=KnowledgeConfig(
  251. enable_extraction=True,
  252. enable_completion_extraction=True,
  253. enable_injection=True,
  254. owner="content_finder_agent",
  255. default_tags={"project": "content_finder"},
  256. default_scopes=["com.piaoquantv.supply"],
  257. default_search_types=["tool", "usecase", "definition"],
  258. default_search_owner="content_finder_agent"
  259. )
  260. )
  261. async for item in runner.run(messages=messages, config=config):
  262. if isinstance(item, Trace):
  263. if not trace_id_holder["id"]:
  264. trace_id_holder["id"] = item.trace_id
  265. trace_id_ready.set()
  266. logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
  267. if item.status == "completed":
  268. stats["completed_tasks"] += 1
  269. logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
  270. break
  271. elif item.status == "failed":
  272. stats["failed_tasks"] += 1
  273. logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
  274. break
  275. except Exception as e:
  276. stats["failed_tasks"] += 1
  277. logger.error(f"任务异常 [api]: {e}", exc_info=True)
  278. if not trace_id_holder["id"]:
  279. trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
  280. trace_id_ready.set()
  281. # 启动后台任务
  282. stats["total_tasks"] += 1
  283. asyncio.create_task(run_and_capture())
  284. # 等待 trace_id(最多 5 秒)
  285. try:
  286. await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
  287. except asyncio.TimeoutError:
  288. logger.error("获取 trace_id 超时")
  289. raise HTTPException(status_code=500, detail="任务启动超时")
  290. trace_id = trace_id_holder["id"]
  291. return TaskResponse(
  292. trace_id=trace_id,
  293. status="started",
  294. query=query,
  295. message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
  296. )
  297. @app.get("/health")
  298. async def health_check():
  299. """健康检查"""
  300. return {
  301. "status": "ok",
  302. "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
  303. "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
  304. "scheduler_running": scheduler.running,
  305. "stats": stats
  306. }
  307. @app.get("/")
  308. async def root():
  309. """根路径"""
  310. return {
  311. "service": "内容寻找服务",
  312. "version": "1.0.0",
  313. "endpoints": {
  314. "create_task": "POST /api/tasks",
  315. "health": "GET /health"
  316. }
  317. }
  318. # ============ 启动事件 ============
  319. @app.on_event("startup")
  320. async def startup():
  321. """服务启动时初始化"""
  322. logger.info("=" * 60)
  323. logger.info("内容寻找服务启动中...")
  324. logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")
  325. logger.info(f"定时器时区: {SCHEDULER_TIMEZONE}")
  326. logger.info(
  327. f"定时策略:每 {SCHEDULE_INTERVAL_MINUTES} 分钟轮询当天需求;"
  328. f"单次任务超时 {TASK_TIMEOUT_SECONDS}s"
  329. )
  330. asyncio.create_task(run_startup_resume())
  331. job = scheduler.add_job(
  332. scheduled_tick,
  333. "interval",
  334. minutes=SCHEDULE_INTERVAL_MINUTES,
  335. misfire_grace_time=300,
  336. coalesce=True,
  337. max_instances=1,
  338. )
  339. scheduler.start()
  340. logger.info(f"定时任务已注册: id={job.id}, next_run_time={job.next_run_time}")
  341. logger.info("服务启动完成")
  342. logger.info("=" * 60)
  343. @app.on_event("shutdown")
  344. async def shutdown():
  345. """服务关闭时清理"""
  346. logger.info("服务关闭中...")
  347. if scheduler.running:
  348. scheduler.shutdown()
  349. logger.info("服务已关闭")
  350. # ============ 主函数 ============
  351. if __name__ == "__main__":
  352. import uvicorn
  353. port = int(os.getenv("PORT", "8080"))
  354. host = os.getenv("HOST", "0.0.0.0")
  355. logger.info(f"启动服务: http://{host}:{port}")
  356. uvicorn.run(app, host=host, port=port)