# web_api.py
  1. """
  2. demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
  3. """
  4. import asyncio
  5. import os
  6. import importlib
  7. import sys
  8. from datetime import datetime, timedelta
  9. from pathlib import Path
  10. from typing import Any, Literal, Optional
  11. from zoneinfo import ZoneInfo
  12. from fastapi import FastAPI, HTTPException
  13. from pydantic import BaseModel
  14. from examples.demand.db_manager import exist_cluster_tree
  15. # 添加项目根目录到 Python 路径(与 run.py 保持一致)
  16. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  17. from examples.demand.zengzhang_prepare import zengzhang_prepare
  18. from examples.demand.changwen_prepare import changwen_prepare
  19. from examples.demand.mysql import mysql_db
  20. from examples.demand.piaoquan_prepare import piaoquan_prepare
  21. from examples.demand.run import _create_demand_task, main as run_demand
  22. app = FastAPI(title="demand web api")
  23. # APScheduler:使用动态导入避免环境未安装时直接导入失败
  24. try:
  25. _aps_asyncio_mod = importlib.import_module("apscheduler.schedulers.asyncio")
  26. _aps_cron_mod = importlib.import_module("apscheduler.triggers.cron")
  27. AsyncIOScheduler = getattr(_aps_asyncio_mod, "AsyncIOScheduler")
  28. CronTrigger = getattr(_aps_cron_mod, "CronTrigger")
  29. except Exception: # pragma: no cover
  30. AsyncIOScheduler = None # type: ignore[assignment]
  31. CronTrigger = None # type: ignore[assignment]
class DemandStartRequest(BaseModel):
    # Request body for POST /demand/start.

    # Cluster name: handed to the prepare step; its first 32 characters are
    # used as the demand_task name.
    cluster_name: str
    # Target platform: selects piaoquan_prepare or changwen_prepare in the handler.
    platform_type: Literal["piaoquan", "changwen"]
    # Forwarded unchanged to run_demand as its third positional argument.
    count: int
  36. def _get_demand_schedule_cluster_platform_list() -> list[dict]:
  37. """动态获取定时任务列表。"""
  38. try:
  39. # 延迟导入:避免服务启动时因 ODPS 依赖缺失直接失败
  40. from examples.demand.data_query_tools import get_demand_merge_level2_names
  41. return get_demand_merge_level2_names() or []
  42. except Exception as e: # pragma: no cover
  43. print(f"获取定时任务列表失败: {e}")
  44. return []
  45. # 是否开启定时任务(可选,通过环境变量覆盖)
  46. DEMAND_SCHEDULER_ENABLED: bool = os.getenv("DEMAND_SCHEDULER_ENABLED", "1").strip() == "1"
  47. DEMAND_SCHEDULER_START_HOUR: int = 2
  48. # 定时任务统一使用北京时间,避免服务器时区(如 UTC)带来的偏差
  49. BEIJING_TZ = ZoneInfo("Asia/Shanghai")
  50. def _get_today_time_window(now: datetime) -> tuple[datetime, datetime]:
  51. """返回今天的 [start, end) 时间窗口(本地时区)。"""
  52. start_of_today = datetime(year=now.year, month=now.month, day=now.day, tzinfo=now.tzinfo)
  53. end_of_today = start_of_today + timedelta(days=1)
  54. return start_of_today, end_of_today
  55. async def demand_start_sync(cluster_name: str, platform_type: Literal["piaoquan", "changwen"], count) -> dict:
  56. """
  57. 与 /demand/start 同一执行链路,但不创建后台任务:prepare -> create demand_task -> 串行 await run_demand。
  58. """
  59. # prepare 阶段是同步的(当前示例代码为 sync),这里保持同步串行语义
  60. execution_id = None
  61. task_id: Optional[int] = None
  62. try:
  63. if platform_type == "piaoquan":
  64. exist = exist_cluster_tree(cluster_name)
  65. if not exist:
  66. raise ValueError("获取聚类树失败")
  67. execution_id = piaoquan_prepare(cluster_name)
  68. elif platform_type == "changwen":
  69. execution_id = changwen_prepare(cluster_name)
  70. elif platform_type == "zengzhang":
  71. execution_id = zengzhang_prepare(cluster_name)
  72. if not execution_id:
  73. raise ValueError("获取 execution_id 失败")
  74. task_name = cluster_name[:32] if cluster_name else None
  75. task_id = _create_demand_task(
  76. execution_id=execution_id,
  77. name=task_name,
  78. platform=platform_type,
  79. )
  80. if not task_id:
  81. raise ValueError("创建 demand_task 失败")
  82. # run_once 内部 finally 会把 task 状态写回 MySQL
  83. result = await run_demand(
  84. cluster_name,
  85. platform_type,
  86. count,
  87. execution_id=execution_id,
  88. task_id=task_id,
  89. )
  90. return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id, "result": result}
  91. except Exception as e:
  92. return {
  93. "ok": False,
  94. "message": f"执行失败: {e}",
  95. "task_id": task_id,
  96. "execution_id": execution_id,
  97. "result": None,
  98. }
  99. def _today_has_status_0_or_1(cluster_name: str, platform_type: str, now: datetime) -> bool:
  100. """
  101. 查找 demand_task:
  102. - 限制为今天(create_time)
  103. - name 与 platform 精确匹配
  104. - 若存在 status 为 0 或 1 的记录,则跳过
  105. """
  106. start_of_today, end_of_today = _get_today_time_window(now)
  107. # MySQL DATETIME 一般按无时区存储,这里使用北京时间对应的“本地时间”窗口做过滤
  108. start_of_today_naive = start_of_today.replace(tzinfo=None)
  109. end_of_today_naive = end_of_today.replace(tzinfo=None)
  110. return mysql_db.exists(
  111. "demand_task",
  112. where=(
  113. "name = %s "
  114. "AND platform = %s "
  115. "AND status IN (0, 1) "
  116. "AND create_time >= %s "
  117. "AND create_time < %s"
  118. ),
  119. where_params=(
  120. str(cluster_name)[:32],
  121. str(platform_type)[:32],
  122. start_of_today_naive,
  123. end_of_today_naive,
  124. ),
  125. )
  126. async def demand_scheduled_run_once() -> None:
  127. """
  128. 任务批处理(串行):
  129. 遍历配置列表 -> 查当天 demand_task -> 匹配 cluster_name/name & platform_type/platform
  130. 若存在 status=0 或 1 的记录则跳过;否则执行一次 demand_start_sync。
  131. """
  132. # ODPS 查询是阻塞调用,放到线程里避免阻塞事件循环
  133. demand_schedule_cluster_platform_list = await asyncio.to_thread(_get_demand_schedule_cluster_platform_list)
  134. if not demand_schedule_cluster_platform_list:
  135. return
  136. now = datetime.now(BEIJING_TZ)
  137. for item in demand_schedule_cluster_platform_list:
  138. cluster_name = item.get("cluster_name")
  139. platform_type = item.get("platform_type")
  140. count = item.get("count")
  141. if not cluster_name or platform_type not in ("piaoquan", "changwen"):
  142. continue
  143. if _today_has_status_0_or_1(cluster_name, platform_type, now=now):
  144. print(f"[scheduler] skip: cluster={cluster_name}, platform={platform_type} (today has status 0/1)")
  145. continue
  146. print(f"[scheduler] run: cluster={cluster_name}, platform={platform_type}, count={count}")
  147. await demand_start_sync(cluster_name=cluster_name, platform_type=platform_type, count=count) # 串行执行
  148. _demand_scheduler: Optional[Any] = None
  149. _demand_scheduler_lock = asyncio.Lock()
  150. async def _demand_scheduler_job() -> None:
  151. """
  152. 定时任务 job:
  153. - 串行执行(防止并发)
  154. - 遍历配置 -> 今日过滤 demand_task -> 跳过/执行
  155. """
  156. if _demand_scheduler_lock.locked():
  157. return
  158. # 指定日期跳过:北京时间 2026-04-08 不执行定时任务
  159. if datetime.now(BEIJING_TZ).date() == datetime(2026, 4, 8).date():
  160. print("[scheduler] skip: 2026-04-08")
  161. return
  162. async with _demand_scheduler_lock:
  163. await demand_scheduled_run_once()
  164. @app.post("/demand/start")
  165. async def demand_start(req: DemandStartRequest):
  166. # 注意:这里会同步计算 execution_id(prepare 阶段),随后 run_once 放到后台异步执行。
  167. if req.platform_type == "piaoquan":
  168. execution_id = piaoquan_prepare(req.cluster_name)
  169. else:
  170. execution_id = changwen_prepare(req.cluster_name)
  171. if not execution_id:
  172. raise HTTPException(status_code=400, detail="获取 execution_id 失败")
  173. task_name = req.cluster_name[:32] if req.cluster_name else None
  174. task_id = _create_demand_task(
  175. execution_id=execution_id,
  176. name=task_name,
  177. platform=req.platform_type,
  178. )
  179. if not task_id:
  180. raise HTTPException(status_code=500, detail="创建 demand_task 失败")
  181. async def _job():
  182. # run_once 内部会在 finally 里把 task 状态写回 MySQL。
  183. await run_demand(
  184. req.cluster_name,
  185. req.platform_type,
  186. req.count,
  187. execution_id=execution_id,
  188. task_id=task_id,
  189. )
  190. asyncio.create_task(_job())
  191. return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
  192. @app.on_event("startup")
  193. async def _start_demand_scheduler() -> None:
  194. """启动定时任务(北京时间:2:00–22:30 每 30 分钟,不含 23 点与 0 点)。"""
  195. global _demand_scheduler
  196. if not DEMAND_SCHEDULER_ENABLED:
  197. return
  198. if _demand_scheduler is not None:
  199. return
  200. if AsyncIOScheduler is None or CronTrigger is None:
  201. # 依赖未安装则跳过定时任务
  202. print("[scheduler] apscheduler 未安装,跳过定时任务启动")
  203. return
  204. scheduler = AsyncIOScheduler(timezone=BEIJING_TZ)
  205. start_h = DEMAND_SCHEDULER_START_HOUR
  206. # DEMAND_SCHEDULER_START_HOUR–22:每 30 分钟(默认 2:00–22:30,不含 23 点)
  207. if start_h > 22:
  208. print(f"[scheduler] DEMAND_SCHEDULER_START_HOUR={start_h} > 22,无法注册 cron,跳过定时任务")
  209. else:
  210. scheduler.add_job(
  211. func=_demand_scheduler_job,
  212. trigger=CronTrigger(hour=f"{start_h}-22", minute="0,30", timezone=BEIJING_TZ),
  213. id="demand_scheduler_job_main",
  214. replace_existing=True,
  215. max_instances=1,
  216. coalesce=True,
  217. )
  218. scheduler.start()
  219. _demand_scheduler = scheduler
  220. @app.get("/demand/task/{task_id}/status")
  221. def demand_task_status(task_id: int, max_log_chars: int = 2000):
  222. row = mysql_db.select_one(
  223. "demand_task",
  224. columns="id, execution_id, name, platform, status, log",
  225. where="id = %s",
  226. where_params=(task_id,),
  227. )
  228. if not row:
  229. raise HTTPException(status_code=404, detail="task not found")
  230. status = int(row.get("status") or 0)
  231. status_map = {0: "running", 1: "completed", 2: "failed"}
  232. log_text = row.get("log") or ""
  233. if max_log_chars and isinstance(log_text, str) and len(log_text) > max_log_chars:
  234. log_text = log_text[:max_log_chars] + "...(truncated)"
  235. execution_id = row.get("execution_id")
  236. final_text: Optional[str] = None
  237. if status == 1 and execution_id:
  238. try:
  239. result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
  240. if result_path.exists():
  241. with open(result_path, "r", encoding="utf-8") as f:
  242. final_text = f.read()
  243. except Exception:
  244. final_text = None
  245. return {
  246. "task_id": task_id,
  247. "execution_id": execution_id,
  248. "name": row.get("name"),
  249. "platform": row.get("platform"),
  250. "status": status,
  251. "status_text": status_map.get(status, "unknown"),
  252. "final_text": final_text,
  253. "log": log_text,
  254. }
  255. @app.get("/demand/tasks")
  256. def demand_tasks(
  257. status: Optional[int] = None,
  258. name: Optional[str] = None,
  259. platform_type: Optional[str] = None,
  260. page: int = 1,
  261. page_size: int = 20,
  262. ):
  263. where_parts: list[str] = []
  264. where_params: list = []
  265. if status is not None:
  266. status_int = int(status)
  267. if status_int not in (0, 1, 2):
  268. raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
  269. where_parts.append("status = %s")
  270. where_params.append(status_int)
  271. if name:
  272. name_str = str(name).strip()
  273. if name_str:
  274. # 支持模糊匹配:根据需求名称字段(varchar(32))
  275. where_parts.append("name LIKE %s")
  276. where_params.append(f"%{name_str}%")
  277. if platform_type:
  278. platform_str = str(platform_type).strip()
  279. if platform_str:
  280. where_parts.append("platform = %s")
  281. where_params.append(platform_str)
  282. where = " AND ".join(where_parts)
  283. params = tuple(where_params) if where_params else None
  284. data = mysql_db.paginate(
  285. "demand_task",
  286. page=page,
  287. page_size=page_size,
  288. columns="id, execution_id, name, platform, status, create_time, update_time",
  289. where=where,
  290. where_params=params,
  291. order_by="id DESC",
  292. )
  293. # 返回分页结构(data + pagination),便于前端直接展示
  294. return data
  295. def run_server():
  296. import uvicorn
  297. uvicorn.run(app, host="0.0.0.0", port=7000)
  298. if __name__ == "__main__":
  299. run_server()