web_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. """
  2. demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
  3. """
  4. import asyncio
  5. import os
  6. import importlib
  7. import sys
  8. from datetime import datetime, timedelta
  9. from pathlib import Path
  10. from typing import Any, Literal, Optional
  11. from fastapi import FastAPI, HTTPException
  12. from pydantic import BaseModel
  13. # 添加项目根目录到 Python 路径(与 run.py 保持一致)
  14. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  15. from examples.demand.changwen_prepare import changwen_prepare
  16. from examples.demand.mysql import mysql_db
  17. from examples.demand.piaoquan_prepare import piaoquan_prepare
  18. from examples.demand.run import _create_demand_task, main as run_demand
  19. app = FastAPI(title="demand web api")
  20. # APScheduler:使用动态导入避免环境未安装时直接导入失败
  21. try:
  22. _aps_asyncio_mod = importlib.import_module("apscheduler.schedulers.asyncio")
  23. _aps_cron_mod = importlib.import_module("apscheduler.triggers.cron")
  24. AsyncIOScheduler = getattr(_aps_asyncio_mod, "AsyncIOScheduler")
  25. CronTrigger = getattr(_aps_cron_mod, "CronTrigger")
  26. except Exception: # pragma: no cover
  27. AsyncIOScheduler = None # type: ignore[assignment]
  28. CronTrigger = None # type: ignore[assignment]
class DemandStartRequest(BaseModel):
    # Request body for POST /demand/start.
    # Cluster identifier handed to the platform's prepare step; its first 32
    # characters are also stored as demand_task.name.
    cluster_name: str
    # "piaoquan" runs piaoquan_prepare; the other value runs changwen_prepare.
    platform_type: Literal["piaoquan", "changwen"]
  32. # 定时任务配置:请按需修改/补齐
  33. # 说明:平台映射关系由 platform_type 决定;cluster_name 将用于匹配 demand_task.name
  34. DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST: list[dict] = []
  35. # 是否开启定时任务(可选,通过环境变量覆盖)
  36. DEMAND_SCHEDULER_ENABLED: bool = os.getenv("DEMAND_SCHEDULER_ENABLED", "1").strip() == "1"
  37. DEMAND_SCHEDULER_START_HOUR: int = 2
  38. def _get_today_time_window(now: datetime) -> tuple[datetime, datetime]:
  39. """返回今天的 [start, end) 时间窗口(本地时区)。"""
  40. start_of_today = datetime(year=now.year, month=now.month, day=now.day)
  41. end_of_today = start_of_today + timedelta(days=1)
  42. return start_of_today, end_of_today
  43. async def demand_start_sync(cluster_name: str, platform_type: Literal["piaoquan", "changwen"]) -> dict:
  44. """
  45. 与 /demand/start 同一执行链路,但不创建后台任务:prepare -> create demand_task -> 串行 await run_demand。
  46. """
  47. # prepare 阶段是同步的(当前示例代码为 sync),这里保持同步串行语义
  48. if platform_type == "piaoquan":
  49. execution_id = piaoquan_prepare(cluster_name)
  50. else:
  51. execution_id = changwen_prepare(cluster_name)
  52. if not execution_id:
  53. raise HTTPException(status_code=400, detail="获取 execution_id 失败")
  54. task_name = cluster_name[:32] if cluster_name else None
  55. task_id = _create_demand_task(
  56. execution_id=execution_id,
  57. name=task_name,
  58. platform=platform_type,
  59. )
  60. if not task_id:
  61. raise HTTPException(status_code=500, detail="创建 demand_task 失败")
  62. # run_once 内部 finally 会把 task 状态写回 MySQL
  63. result = await run_demand(
  64. cluster_name,
  65. platform_type,
  66. execution_id=execution_id,
  67. task_id=task_id,
  68. )
  69. return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id, "result": result}
  70. def _today_has_status_0_or_1(cluster_name: str, platform_type: str, now: datetime) -> bool:
  71. """
  72. 查找 demand_task:
  73. - 限制为今天(create_time)
  74. - name 与 platform 精确匹配
  75. - 若存在 status 为 0 或 1 的记录,则跳过
  76. """
  77. start_of_today, end_of_today = _get_today_time_window(now)
  78. return mysql_db.exists(
  79. "demand_task",
  80. where=(
  81. "name = %s "
  82. "AND platform = %s "
  83. "AND status IN (0, 1) "
  84. "AND create_time >= %s "
  85. "AND create_time < %s"
  86. ),
  87. where_params=(
  88. str(cluster_name)[:32],
  89. str(platform_type)[:32],
  90. start_of_today,
  91. end_of_today,
  92. ),
  93. )
  94. async def demand_scheduled_run_once() -> None:
  95. """
  96. 任务批处理(串行):
  97. 遍历配置列表 -> 查当天 demand_task -> 匹配 cluster_name/name & platform_type/platform
  98. 若存在 status=0 或 1 的记录则跳过;否则执行一次 demand_start_sync。
  99. """
  100. if not DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST:
  101. return
  102. now = datetime.now()
  103. for item in DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST:
  104. cluster_name = item.get("cluster_name")
  105. platform_type = item.get("platform_type")
  106. if not cluster_name or platform_type not in ("piaoquan", "changwen"):
  107. continue
  108. if _today_has_status_0_or_1(cluster_name, platform_type, now=now):
  109. print(f"[scheduler] skip: cluster={cluster_name}, platform={platform_type} (today has status 0/1)")
  110. continue
  111. print(f"[scheduler] run: cluster={cluster_name}, platform={platform_type}")
  112. await demand_start_sync(cluster_name=cluster_name, platform_type=platform_type) # 串行执行
  113. _demand_scheduler: Optional[Any] = None
  114. _demand_scheduler_lock = asyncio.Lock()
  115. async def _demand_scheduler_job() -> None:
  116. """
  117. 定时任务 job:
  118. - 串行执行(防止并发)
  119. - 遍历配置 -> 今日过滤 demand_task -> 跳过/执行
  120. """
  121. if _demand_scheduler_lock.locked():
  122. return
  123. async with _demand_scheduler_lock:
  124. await demand_scheduled_run_once()
  125. @app.post("/demand/start")
  126. async def demand_start(req: DemandStartRequest):
  127. # 注意:这里会同步计算 execution_id(prepare 阶段),随后 run_once 放到后台异步执行。
  128. if req.platform_type == "piaoquan":
  129. execution_id = piaoquan_prepare(req.cluster_name)
  130. else:
  131. execution_id = changwen_prepare(req.cluster_name)
  132. if not execution_id:
  133. raise HTTPException(status_code=400, detail="获取 execution_id 失败")
  134. task_name = req.cluster_name[:32] if req.cluster_name else None
  135. task_id = _create_demand_task(
  136. execution_id=execution_id,
  137. name=task_name,
  138. platform=req.platform_type,
  139. )
  140. if not task_id:
  141. raise HTTPException(status_code=500, detail="创建 demand_task 失败")
  142. async def _job():
  143. # run_once 内部会在 finally 里把 task 状态写回 MySQL。
  144. await run_demand(
  145. req.cluster_name,
  146. req.platform_type,
  147. execution_id=execution_id,
  148. task_id=task_id,
  149. )
  150. asyncio.create_task(_job())
  151. return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
  152. @app.on_event("startup")
  153. async def _start_demand_scheduler() -> None:
  154. """启动定时任务(cron 触发,2:00-24:00 每 30 分钟)。"""
  155. global _demand_scheduler
  156. if not DEMAND_SCHEDULER_ENABLED:
  157. return
  158. if _demand_scheduler is not None:
  159. return
  160. if AsyncIOScheduler is None or CronTrigger is None:
  161. # 依赖未安装则跳过定时任务
  162. print("[scheduler] apscheduler 未安装,跳过定时任务启动")
  163. return
  164. scheduler = AsyncIOScheduler()
  165. # 02:00 - 23:30:每 30 分钟一次
  166. scheduler.add_job(
  167. func=_demand_scheduler_job,
  168. trigger=CronTrigger(hour=f"{DEMAND_SCHEDULER_START_HOUR}-23", minute="0,30"),
  169. id="demand_scheduler_job_main",
  170. replace_existing=True,
  171. max_instances=1,
  172. coalesce=True,
  173. )
  174. # 24:00(即下一天 00:00):每天一次
  175. scheduler.add_job(
  176. func=_demand_scheduler_job,
  177. trigger=CronTrigger(hour="0", minute="0"),
  178. id="demand_scheduler_job_midnight",
  179. replace_existing=True,
  180. max_instances=1,
  181. coalesce=True,
  182. )
  183. scheduler.start()
  184. _demand_scheduler = scheduler
  185. @app.get("/demand/task/{task_id}/status")
  186. def demand_task_status(task_id: int, max_log_chars: int = 2000):
  187. row = mysql_db.select_one(
  188. "demand_task",
  189. columns="id, execution_id, name, platform, status, log",
  190. where="id = %s",
  191. where_params=(task_id,),
  192. )
  193. if not row:
  194. raise HTTPException(status_code=404, detail="task not found")
  195. status = int(row.get("status") or 0)
  196. status_map = {0: "running", 1: "completed", 2: "failed"}
  197. log_text = row.get("log") or ""
  198. if max_log_chars and isinstance(log_text, str) and len(log_text) > max_log_chars:
  199. log_text = log_text[:max_log_chars] + "...(truncated)"
  200. execution_id = row.get("execution_id")
  201. final_text: Optional[str] = None
  202. if status == 1 and execution_id:
  203. try:
  204. result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
  205. if result_path.exists():
  206. with open(result_path, "r", encoding="utf-8") as f:
  207. final_text = f.read()
  208. except Exception:
  209. final_text = None
  210. return {
  211. "task_id": task_id,
  212. "execution_id": execution_id,
  213. "name": row.get("name"),
  214. "platform": row.get("platform"),
  215. "status": status,
  216. "status_text": status_map.get(status, "unknown"),
  217. "final_text": final_text,
  218. "log": log_text,
  219. }
  220. @app.get("/demand/tasks")
  221. def demand_tasks(
  222. status: Optional[int] = None,
  223. name: Optional[str] = None,
  224. platform_type: Optional[str] = None,
  225. page: int = 1,
  226. page_size: int = 20,
  227. ):
  228. where_parts: list[str] = []
  229. where_params: list = []
  230. if status is not None:
  231. status_int = int(status)
  232. if status_int not in (0, 1, 2):
  233. raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
  234. where_parts.append("status = %s")
  235. where_params.append(status_int)
  236. if name:
  237. name_str = str(name).strip()
  238. if name_str:
  239. # 支持模糊匹配:根据需求名称字段(varchar(32))
  240. where_parts.append("name LIKE %s")
  241. where_params.append(f"%{name_str}%")
  242. if platform_type:
  243. platform_str = str(platform_type).strip()
  244. if platform_str:
  245. where_parts.append("platform = %s")
  246. where_params.append(platform_str)
  247. where = " AND ".join(where_parts)
  248. params = tuple(where_params) if where_params else None
  249. data = mysql_db.paginate(
  250. "demand_task",
  251. page=page,
  252. page_size=page_size,
  253. columns="id, execution_id, name, platform, status, create_time, update_time",
  254. where=where,
  255. where_params=params,
  256. order_by="id DESC",
  257. )
  258. # 返回分页结构(data + pagination),便于前端直接展示
  259. return data
  260. def run_server():
  261. import uvicorn
  262. uvicorn.run(app, host="0.0.0.0", port=7000)
  263. if __name__ == "__main__":
  264. run_server()