# server.py
"""
Content-finding service.

Provides:
1. API endpoint: POST /api/tasks - trigger a content-finding task.
2. Scheduling: on startup, first resume tasks in demand_find_task with
   status=running; afterwards poll every 2 minutes and, when no task is
   currently executing, pick from demand_content the row for today
   (dt=YYYYMMDD) with the highest score that has no task record yet
   (category is not considered) and run it.
3. Concurrency control: cap the number of concurrent tasks; the scheduler
   skips a polling round while a task is already executing.
4. A single finding task may run for at most 15 minutes; on timeout it is
   recorded as failed and written back to demand_find_task.
"""
  10. import asyncio
  11. import logging
  12. import os
  13. import uuid
  14. from datetime import datetime
  15. from pathlib import Path
  16. from typing import Optional
  17. import sys
  18. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  19. from fastapi import FastAPI, HTTPException
  20. from pydantic import BaseModel
  21. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  22. from zoneinfo import ZoneInfo
  23. from dotenv import load_dotenv
  24. load_dotenv()
  25. import core
  26. from db import (
  27. create_task_record,
  28. get_first_running_task,
  29. get_one_today_unprocessed_demand,
  30. update_task_status,
  31. update_task_on_complete,
  32. )
  33. from db.schedule import STATUS_RUNNING, STATUS_SUCCESS, STATUS_FAILED
  34. # 配置日志
  35. log_dir = Path(__file__).parent / '.cache'
  36. log_dir.mkdir(exist_ok=True)
  37. logging.basicConfig(
  38. level=logging.INFO,
  39. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  40. handlers=[
  41. logging.FileHandler(log_dir / 'server.log'),
  42. logging.StreamHandler()
  43. ]
  44. )
  45. logger = logging.getLogger(__name__)
  46. # FastAPI 应用
  47. app = FastAPI(
  48. title="内容寻找服务",
  49. version="1.0.0",
  50. description="抖音内容寻找 Agent 服务"
  51. )
  52. # 定时调度器(默认用中国时区,避免容器 UTC 导致错过预期时间点)
  53. SCHEDULER_TIMEZONE = os.getenv("SCHEDULER_TIMEZONE", os.getenv("TZ", "Asia/Shanghai"))
  54. scheduler = AsyncIOScheduler(timezone=ZoneInfo(SCHEDULER_TIMEZONE))
  55. # 并发控制
  56. MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "1"))
  57. task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
  58. # 定时:轮询间隔(分钟)、单次任务超时(秒,默认 15 分钟)
  59. SCHEDULE_INTERVAL_MINUTES = int(os.getenv("SCHEDULE_INTERVAL_MINUTES", "2"))
  60. TASK_TIMEOUT_SECONDS = int(os.getenv("SCHEDULE_TASK_TIMEOUT_SECONDS", "900"))
  61. # 统计信息
  62. stats = {
  63. "total_tasks": 0,
  64. "completed_tasks": 0,
  65. "failed_tasks": 0,
  66. "scheduled_tasks": 0
  67. }
  68. # ============ 数据模型 ============
  69. class TaskRequest(BaseModel):
  70. query: Optional[str] = None
  71. demand_id: Optional[int] = None
  72. class TaskResponse(BaseModel):
  73. trace_id: str
  74. status: str
  75. query: str
  76. message: str
  77. # ============ 核心函数 ============
  78. def _update_scheduled_task_complete(demand_id: int, trace_id: str, status: int) -> None:
  79. """定时任务完成时更新 trace_id 和 status,静默处理异常"""
  80. try:
  81. update_task_on_complete(demand_id, trace_id, status)
  82. except Exception as e:
  83. logger.warning(f"更新任务状态失败: {e}")
  84. async def execute_task(
  85. query: str,
  86. demand_id: Optional[int] = None,
  87. task_type: str = "api",
  88. ):
  89. """
  90. 执行任务(带并发控制)
  91. Args:
  92. query: 查询内容
  93. demand_id: 需求 id(demand_content.id,关联 demand_content 表)
  94. task_type: 任务类型("api" 或 "scheduled")
  95. """
  96. async with task_semaphore:
  97. current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1
  98. logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
  99. start_time = datetime.now()
  100. stats["total_tasks"] += 1
  101. if task_type == "scheduled":
  102. stats["scheduled_tasks"] += 1
  103. if task_type == "scheduled" and demand_id is not None:
  104. try:
  105. update_task_status("", demand_id, STATUS_RUNNING)
  106. except Exception as e:
  107. logger.warning(f"更新任务状态为执行中失败: {e}")
  108. try:
  109. result = await asyncio.wait_for(
  110. core.run_agent(
  111. query, demand_id=demand_id, stream_output=False, log_assistant_text=True
  112. ),
  113. timeout=float(TASK_TIMEOUT_SECONDS),
  114. )
  115. duration = (datetime.now() - start_time).total_seconds()
  116. if result["status"] == "completed":
  117. stats["completed_tasks"] += 1
  118. logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
  119. if task_type == "scheduled" and demand_id is not None:
  120. _update_scheduled_task_complete(demand_id, result["trace_id"], STATUS_SUCCESS)
  121. else:
  122. stats["failed_tasks"] += 1
  123. logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
  124. if task_type == "scheduled" and demand_id is not None:
  125. _update_scheduled_task_complete(demand_id, result.get("trace_id") or "", STATUS_FAILED)
  126. except asyncio.TimeoutError:
  127. stats["failed_tasks"] += 1
  128. duration = (datetime.now() - start_time).total_seconds()
  129. logger.error(
  130. f"任务超时 [{task_type}]: 超过 {TASK_TIMEOUT_SECONDS}s,记为失败, 耗时={duration:.1f}s"
  131. )
  132. if task_type == "scheduled" and demand_id is not None:
  133. _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  134. except Exception as e:
  135. stats["failed_tasks"] += 1
  136. duration = (datetime.now() - start_time).total_seconds()
  137. logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
  138. if task_type == "scheduled" and demand_id is not None:
  139. _update_scheduled_task_complete(demand_id, "", STATUS_FAILED)
  140. def _today_dt_int() -> int:
  141. """当天 demand_content.dt 约定为 YYYYMMDD 整数(如 20260402),与定时器时区一致。"""
  142. return int(datetime.now(ZoneInfo(SCHEDULER_TIMEZONE)).strftime("%Y%m%d"))
  143. def _has_running_content_task() -> bool:
  144. """
  145. 本进程内是否有内容寻找任务正在执行(占用并发槽)。
  146. 与 execute_task 共用 task_semaphore,含 API 触发与定时触发。
  147. """
  148. return task_semaphore._value != MAX_CONCURRENT_TASKS
  149. async def scheduled_tick():
  150. """
  151. 按 SCHEDULE_INTERVAL_MINUTES 轮询:若当前无任务在执行,则从 demand_content 取
  152. 当天(dt=今日)、尚未出现在 demand_find_task 中且 score 最高的一条需求并执行。
  153. """
  154. logger.info("定时任务触发(scheduled_tick)")
  155. if _has_running_content_task():
  156. logger.info("定时任务跳过:仍有任务在执行(并发槽占用中)")
  157. return
  158. dt = _today_dt_int()
  159. item = get_one_today_unprocessed_demand(dt=dt)
  160. if not item:
  161. logger.info(f"定时任务跳过:无待处理需求(dt={dt} 或均已建任务)")
  162. return
  163. demand_content_id = item.get("demand_content_id")
  164. query = (item.get("query") or "").strip()
  165. if demand_content_id is None or not query:
  166. logger.info("定时任务跳过:查询结果无效")
  167. return
  168. score = item.get("score")
  169. logger.info(
  170. f"定时任务领取(当天 score 最高):demand_content_id={demand_content_id}, "
  171. f"dt={dt}, score={score}"
  172. )
  173. create_task_record(demand_content_id)
  174. await execute_task(query=query, demand_id=demand_content_id, task_type="scheduled")
  175. async def run_startup_resume():
  176. """
  177. 启动后先执行 demand_find_task 中 status=执行中(1) 的任务(理论上仅一条)。
  178. """
  179. try:
  180. row = get_first_running_task()
  181. if not row:
  182. logger.info("启动恢复:无执行中(status=1)的 demand_find_task")
  183. return
  184. demand_content_id = row.get("demand_content_id")
  185. query = (row.get("query") or "").strip()
  186. if demand_content_id is None or not query:
  187. logger.warning("启动恢复:执行中任务数据不完整,跳过")
  188. return
  189. logger.info(f"启动恢复:执行 demand_find_task status=1, demand_content_id={demand_content_id}")
  190. await execute_task(query=query, demand_id=int(demand_content_id), task_type="scheduled")
  191. except Exception as e:
  192. logger.error(f"启动恢复失败: {e}", exc_info=True)
  193. # ============ API 接口 ============
  194. @app.post("/api/tasks", response_model=TaskResponse)
  195. async def create_task(request: TaskRequest):
  196. """
  197. 创建内容寻找任务
  198. Args:
  199. request.query: 查询内容(可选,不传则使用默认值)
  200. Returns:
  201. {
  202. "trace_id": "20260317_103046_xyz789",
  203. "status": "started",
  204. "query": "...",
  205. "message": "任务已启动,结果将保存到 .cache/traces/xxx/"
  206. }
  207. """
  208. # 获取 query 和 demand_id
  209. query = request.query or core.DEFAULT_QUERY
  210. demand_id = request.demand_id
  211. # 用 Event 等待 trace_id
  212. trace_id_ready = asyncio.Event()
  213. trace_id_holder = {"id": None}
  214. async def run_and_capture():
  215. try:
  216. # 获取第一个 Trace 对象来获取 trace_id
  217. from agent import Trace
  218. async with task_semaphore:
  219. # 重新构建 runner 来获取 trace_id
  220. from agent import AgentRunner, RunConfig, FileSystemTraceStore
  221. from agent.llm import create_openrouter_llm_call
  222. from agent.llm.prompts import SimplePrompt
  223. from agent.tools.builtin.knowledge import KnowledgeConfig
  224. prompt_path = Path(__file__).parent / "content_finder.md"
  225. prompt = SimplePrompt(prompt_path)
  226. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  227. demand_id_str = str(demand_id) if demand_id is not None else ""
  228. messages = prompt.build_messages(query=query, trace_dir=trace_dir, demand_id=demand_id_str)
  229. api_key = os.getenv("OPEN_ROUTER_API_KEY")
  230. model_name = prompt.config.get("model", "sonnet-4.6")
  231. model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
  232. temperature = float(prompt.config.get("temperature", 0.3))
  233. max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
  234. trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
  235. skills_dir = str(Path(__file__).parent / "skills")
  236. Path(trace_dir).mkdir(parents=True, exist_ok=True)
  237. store = FileSystemTraceStore(base_path=trace_dir)
  238. allowed_tools = [
  239. "douyin_search",
  240. "douyin_search_tikhub",
  241. "douyin_user_videos",
  242. "get_content_fans_portrait",
  243. "get_account_fans_portrait",
  244. "batch_fetch_portraits",
  245. "store_results_mysql",
  246. "exec_summary",
  247. ]
  248. runner = AgentRunner(
  249. llm_call=create_openrouter_llm_call(model=model),
  250. trace_store=store,
  251. skills_dir=skills_dir,
  252. )
  253. config = RunConfig(
  254. name="内容寻找",
  255. model=model,
  256. temperature=temperature,
  257. max_iterations=max_iterations,
  258. tools=allowed_tools,
  259. extra_llm_params={"max_tokens": 8192},
  260. knowledge=KnowledgeConfig(
  261. enable_extraction=True,
  262. enable_completion_extraction=True,
  263. enable_injection=True,
  264. owner="content_finder_agent",
  265. default_tags={"project": "content_finder"},
  266. default_scopes=["com.piaoquantv.supply"],
  267. default_search_types=["tool", "usecase", "definition"],
  268. default_search_owner="content_finder_agent"
  269. )
  270. )
  271. async for item in runner.run(messages=messages, config=config):
  272. if isinstance(item, Trace):
  273. if not trace_id_holder["id"]:
  274. trace_id_holder["id"] = item.trace_id
  275. trace_id_ready.set()
  276. logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
  277. if item.status == "completed":
  278. stats["completed_tasks"] += 1
  279. logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
  280. break
  281. elif item.status == "failed":
  282. stats["failed_tasks"] += 1
  283. logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
  284. break
  285. except Exception as e:
  286. stats["failed_tasks"] += 1
  287. logger.error(f"任务异常 [api]: {e}", exc_info=True)
  288. if not trace_id_holder["id"]:
  289. trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
  290. trace_id_ready.set()
  291. # 启动后台任务
  292. stats["total_tasks"] += 1
  293. asyncio.create_task(run_and_capture())
  294. # 等待 trace_id(最多 5 秒)
  295. try:
  296. await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
  297. except asyncio.TimeoutError:
  298. logger.error("获取 trace_id 超时")
  299. raise HTTPException(status_code=500, detail="任务启动超时")
  300. trace_id = trace_id_holder["id"]
  301. return TaskResponse(
  302. trace_id=trace_id,
  303. status="started",
  304. query=query,
  305. message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
  306. )
  307. @app.get("/health")
  308. async def health_check():
  309. """健康检查"""
  310. return {
  311. "status": "ok",
  312. "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
  313. "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
  314. "scheduler_running": scheduler.running,
  315. "stats": stats
  316. }
  317. @app.get("/")
  318. async def root():
  319. """根路径"""
  320. return {
  321. "service": "内容寻找服务",
  322. "version": "1.0.0",
  323. "endpoints": {
  324. "create_task": "POST /api/tasks",
  325. "health": "GET /health"
  326. }
  327. }
  328. # ============ 启动事件 ============
  329. @app.on_event("startup")
  330. async def startup():
  331. """服务启动时初始化"""
  332. logger.info("=" * 60)
  333. logger.info("内容寻找服务启动中...")
  334. logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")
  335. logger.info(f"定时器时区: {SCHEDULER_TIMEZONE}")
  336. logger.info(
  337. f"定时策略:每 {SCHEDULE_INTERVAL_MINUTES} 分钟检查是否空闲,空闲则取当天 score 最高的一条;"
  338. f"单次任务超时 {TASK_TIMEOUT_SECONDS}s"
  339. )
  340. asyncio.create_task(run_startup_resume())
  341. job = scheduler.add_job(
  342. scheduled_tick,
  343. "interval",
  344. minutes=SCHEDULE_INTERVAL_MINUTES,
  345. misfire_grace_time=300,
  346. coalesce=True,
  347. max_instances=1,
  348. )
  349. scheduler.start()
  350. logger.info(f"定时任务已注册: id={job.id}, next_run_time={job.next_run_time}")
  351. logger.info("服务启动完成")
  352. logger.info("=" * 60)
  353. @app.on_event("shutdown")
  354. async def shutdown():
  355. """服务关闭时清理"""
  356. logger.info("服务关闭中...")
  357. if scheduler.running:
  358. scheduler.shutdown()
  359. logger.info("服务已关闭")
  360. # ============ 主函数 ============
  361. if __name__ == "__main__":
  362. import uvicorn
  363. port = int(os.getenv("PORT", "8080"))
  364. host = os.getenv("HOST", "0.0.0.0")
  365. logger.info(f"启动服务: http://{host}:{port}")
  366. uvicorn.run(app, host=host, port=port)