| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- """
- demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
- """
- import asyncio
- import os
- import importlib
- import sys
- from datetime import datetime, timedelta
- from pathlib import Path
- from typing import Any, Literal, Optional
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- # 添加项目根目录到 Python 路径(与 run.py 保持一致)
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from examples.demand.changwen_prepare import changwen_prepare
- from examples.demand.mysql import mysql_db
- from examples.demand.piaoquan_prepare import piaoquan_prepare
- from examples.demand.run import _create_demand_task, main as run_demand
# APScheduler is optional: it is resolved through importlib at runtime so that an
# environment without it can still import this module. When anything in the lookup
# fails, both names are set to None and the startup hook skips scheduling.
try:
    _aps_asyncio_mod = importlib.import_module("apscheduler.schedulers.asyncio")
    _aps_cron_mod = importlib.import_module("apscheduler.triggers.cron")
    AsyncIOScheduler, CronTrigger = (
        getattr(_aps_asyncio_mod, "AsyncIOScheduler"),
        getattr(_aps_cron_mod, "CronTrigger"),
    )
except Exception:  # pragma: no cover
    AsyncIOScheduler = CronTrigger = None  # type: ignore[assignment]

app = FastAPI(title="demand web api")
class DemandStartRequest(BaseModel):
    """JSON body accepted by ``POST /demand/start``."""

    # Cluster to process; its first 32 characters are stored as demand_task.name.
    cluster_name: str
    # Chooses the prepare step: "piaoquan" -> piaoquan_prepare, "changwen" -> changwen_prepare.
    platform_type: Literal["piaoquan", "changwen"]
# Scheduled-job configuration -- edit / extend as needed.
# Each entry is a dict like {"cluster_name": "...", "platform_type": "piaoquan" | "changwen"}.
# platform_type selects the platform mapping; cluster_name is matched against demand_task.name.
DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST: list[dict] = []

# Scheduler switch, overridable via the environment: enabled unless
# DEMAND_SCHEDULER_ENABLED (whitespace-stripped) is something other than "1".
DEMAND_SCHEDULER_ENABLED: bool = os.environ.get("DEMAND_SCHEDULER_ENABLED", "1").strip() == "1"

# First hour of the half-hourly cron window; runs continue through 23:30.
DEMAND_SCHEDULER_START_HOUR: int = 2
- def _get_today_time_window(now: datetime) -> tuple[datetime, datetime]:
- """返回今天的 [start, end) 时间窗口(本地时区)。"""
- start_of_today = datetime(year=now.year, month=now.month, day=now.day)
- end_of_today = start_of_today + timedelta(days=1)
- return start_of_today, end_of_today
async def demand_start_sync(cluster_name: str, platform_type: Literal["piaoquan", "changwen"]) -> dict:
    """Run the same pipeline as ``/demand/start``, but inline instead of as a background task.

    Steps, strictly in order: prepare -> create a demand_task row -> ``await run_demand``.

    Args:
        cluster_name: Cluster to process; its first 32 characters become demand_task.name.
        platform_type: ``"piaoquan"`` uses ``piaoquan_prepare``; any other value uses
            ``changwen_prepare``.

    Returns:
        ``{"ok", "message", "task_id", "execution_id", "result"}`` where ``result`` is the
        value returned by ``run_demand``.

    Raises:
        HTTPException: 400 if prepare returns no execution_id; 500 if the demand_task
            row could not be created.
    """
    # The prepare helpers are synchronous; they are called directly to keep the
    # original strictly-serial semantics.
    prepare = piaoquan_prepare if platform_type == "piaoquan" else changwen_prepare
    execution_id = prepare(cluster_name)
    if not execution_id:
        raise HTTPException(status_code=400, detail="获取 execution_id 失败")

    task_id = _create_demand_task(
        execution_id=execution_id,
        name=cluster_name[:32] if cluster_name else None,
        platform=platform_type,
    )
    if not task_id:
        raise HTTPException(status_code=500, detail="创建 demand_task 失败")

    # Per the original note, run_once writes the task status back to MySQL in its finally block.
    result = await run_demand(cluster_name, platform_type, execution_id=execution_id, task_id=task_id)
    return {
        "ok": True,
        "message": "调用成功",
        "task_id": task_id,
        "execution_id": execution_id,
        "result": result,
    }
def _today_has_status_0_or_1(cluster_name: str, platform_type: str, now: datetime) -> bool:
    """Tell whether a matching demand_task row was created today with status 0 or 1.

    A row matches when all of the following hold:
    - ``name`` equals the first 32 characters of ``cluster_name``;
    - ``platform`` equals the first 32 characters of ``platform_type``;
    - ``status`` is 0 or 1 (running / completed in the status endpoint's mapping);
    - ``create_time`` lies in the day window derived from ``now``.

    The scheduler skips an entry when this returns True.
    """
    window_start, window_end = _get_today_time_window(now)
    clauses = (
        "name = %s",
        "platform = %s",
        "status IN (0, 1)",
        "create_time >= %s",
        "create_time < %s",
    )
    params = (
        str(cluster_name)[:32],  # same 32-char cut applied to the stored task name
        str(platform_type)[:32],
        window_start,
        window_end,
    )
    return mysql_db.exists("demand_task", where=" AND ".join(clauses), where_params=params)
async def demand_scheduled_run_once() -> None:
    """Run one scheduler pass over DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST, serially.

    Each configured entry is handled in turn:
    - entries with an empty ``cluster_name`` or an unsupported ``platform_type`` are ignored;
    - entries with a demand_task row already created today with status 0 or 1 are skipped;
    - otherwise ``demand_start_sync`` is awaited before the next entry is considered.

    A failure in one entry (e.g. the HTTPException raised by demand_start_sync, or a
    database error) is logged and does not stop the remaining entries.
    """
    import traceback

    if not DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST:
        return
    for item in DEMAND_SCHEDULE_CLUSTER_PLATFORM_LIST:
        cluster_name = item.get("cluster_name")
        platform_type = item.get("platform_type")
        if not cluster_name or platform_type not in ("piaoquan", "changwen"):
            continue
        try:
            # Take a fresh timestamp per entry. Earlier entries are awaited to completion
            # and may run past midnight, which would make a pass-level timestamp point
            # at the wrong "today".
            if _today_has_status_0_or_1(cluster_name, platform_type, now=datetime.now()):
                print(f"[scheduler] skip: cluster={cluster_name}, platform={platform_type} (today has status 0/1)")
                continue
            print(f"[scheduler] run: cluster={cluster_name}, platform={platform_type}")
            await demand_start_sync(cluster_name=cluster_name, platform_type=platform_type)  # serial execution
        except Exception as exc:
            # Isolate per-entry failures so one bad cluster does not block the others.
            print(f"[scheduler] error: cluster={cluster_name}, platform={platform_type}: {exc!r}")
            traceback.print_exc()
- _demand_scheduler: Optional[Any] = None
- _demand_scheduler_lock = asyncio.Lock()
- async def _demand_scheduler_job() -> None:
- """
- 定时任务 job:
- - 串行执行(防止并发)
- - 遍历配置 -> 今日过滤 demand_task -> 跳过/执行
- """
- if _demand_scheduler_lock.locked():
- return
- async with _demand_scheduler_lock:
- await demand_scheduled_run_once()
# Strong references to in-flight /demand/start background jobs. The event loop keeps
# only weak references to tasks, so a task nobody references can be garbage-collected
# before it finishes. Entries are removed again when the task completes.
_demand_background_tasks: set[asyncio.Task] = set()


def _on_demand_background_task_done(task: asyncio.Task) -> None:
    """Release a finished background task and report its unhandled exception, if any."""
    _demand_background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Without this, the failure only surfaces as "Task exception was never retrieved".
        print(f"[demand/start] background job failed: {exc!r}")


@app.post("/demand/start")
async def demand_start(req: DemandStartRequest):
    """Start a demand run and return immediately with its task_id and execution_id.

    Inside the request, the prepare step runs and a demand_task row is created.
    ``run_demand`` is then scheduled as a background asyncio task. Poll
    ``GET /demand/task/{task_id}/status`` for its progress.

    Raises:
        HTTPException: 400 if prepare returns no execution_id; 500 if the demand_task
            row could not be created.
    """
    # NOTE(review): prepare is a plain synchronous call inside this async handler, so it
    # blocks the event loop while it runs.
    if req.platform_type == "piaoquan":
        execution_id = piaoquan_prepare(req.cluster_name)
    else:
        execution_id = changwen_prepare(req.cluster_name)
    if not execution_id:
        raise HTTPException(status_code=400, detail="获取 execution_id 失败")

    task_name = req.cluster_name[:32] if req.cluster_name else None
    task_id = _create_demand_task(
        execution_id=execution_id,
        name=task_name,
        platform=req.platform_type,
    )
    if not task_id:
        raise HTTPException(status_code=500, detail="创建 demand_task 失败")

    async def _job():
        # Per the original note, run_once writes the task status back to MySQL in its finally block.
        await run_demand(
            req.cluster_name,
            req.platform_type,
            execution_id=execution_id,
            task_id=task_id,
        )

    background = asyncio.create_task(_job())
    _demand_background_tasks.add(background)
    background.add_done_callback(_on_demand_background_task_done)
    return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
@app.on_event("startup")
async def _start_demand_scheduler() -> None:
    """Start the cron-driven scheduler on app startup (2:00-24:00, every 30 minutes).

    Nothing happens when any of the following is true:
    - DEMAND_SCHEDULER_ENABLED is false;
    - a scheduler was already started in this process;
    - APScheduler could not be imported.

    Two jobs are registered, both running ``_demand_scheduler_job``:
    - every :00 and :30 from DEMAND_SCHEDULER_START_HOUR through 23:30;
    - once daily at 00:00, the "24:00" slot.
    """
    global _demand_scheduler
    if not DEMAND_SCHEDULER_ENABLED or _demand_scheduler is not None:
        return
    if AsyncIOScheduler is None or CronTrigger is None:
        # Optional dependency missing: skip scheduling.
        print("[scheduler] apscheduler 未安装,跳过定时任务启动")
        return

    scheduler = AsyncIOScheduler()
    job_specs = (
        # START_HOUR:00 - 23:30, every 30 minutes
        (
            "demand_scheduler_job_main",
            CronTrigger(hour=f"{DEMAND_SCHEDULER_START_HOUR}-23", minute="0,30"),
        ),
        # 24:00, i.e. 00:00 of the next day, once per day
        ("demand_scheduler_job_midnight", CronTrigger(hour="0", minute="0")),
    )
    for job_id, trigger in job_specs:
        scheduler.add_job(
            func=_demand_scheduler_job,
            trigger=trigger,
            id=job_id,
            replace_existing=True,
            max_instances=1,
            coalesce=True,
        )
    scheduler.start()
    _demand_scheduler = scheduler
@app.get("/demand/task/{task_id}/status")
def demand_task_status(task_id: int, max_log_chars: int = 2000):
    """Return the status of one demand_task, its (possibly truncated) log and final text.

    Args:
        task_id: Primary key of the demand_task row.
        max_log_chars: Truncate ``log`` to this many characters and append
            ``"...(truncated)"``. Zero, negative or None disables truncation.

    Returns:
        Dict with task_id, execution_id, name, platform, numeric ``status``,
        ``status_text`` (0 running, 1 completed, 2 failed, otherwise "unknown"), ``log``
        and ``final_text``. ``final_text`` holds the contents of
        ``output/<execution_id>/result.txt`` (next to this module) for completed tasks,
        else None.

    Raises:
        HTTPException: 404 if no demand_task row has this id.
    """
    row = mysql_db.select_one(
        "demand_task",
        columns="id, execution_id, name, platform, status, log",
        where="id = %s",
        where_params=(task_id,),
    )
    if not row:
        raise HTTPException(status_code=404, detail="task not found")

    status = int(row.get("status") or 0)
    status_map = {0: "running", 1: "completed", 2: "failed"}

    log_text = row.get("log") or ""
    # Only a positive limit truncates. A negative value used to slice from the end
    # (log_text[:-n]) while still appending the "truncated" marker.
    if (
        max_log_chars
        and max_log_chars > 0
        and isinstance(log_text, str)
        and len(log_text) > max_log_chars
    ):
        log_text = log_text[:max_log_chars] + "...(truncated)"

    execution_id = row.get("execution_id")
    final_text: Optional[str] = None
    if status == 1 and execution_id:
        result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
        try:
            final_text = result_path.read_text(encoding="utf-8")
        except (OSError, ValueError):
            # Best effort: a missing/unreadable file (OSError) or a bad encoding
            # (UnicodeDecodeError is a ValueError) just leaves final_text as None.
            final_text = None

    return {
        "task_id": task_id,
        "execution_id": execution_id,
        "name": row.get("name"),
        "platform": row.get("platform"),
        "status": status,
        "status_text": status_map.get(status, "unknown"),
        "final_text": final_text,
        "log": log_text,
    }
@app.get("/demand/tasks")
def demand_tasks(
    status: Optional[int] = None,
    name: Optional[str] = None,
    platform_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
):
    """List demand_task rows (newest id first) with optional filters and pagination.

    Args:
        status: Exact status filter; must be 0, 1 or 2.
        name: Substring (SQL ``LIKE``) filter on ``name``; blank values are ignored.
        platform_type: Exact filter on ``platform``; blank values are ignored.
        page: 1-based page number; must be >= 1.
        page_size: Rows per page; must be >= 1.

    Returns:
        The value returned by ``mysql_db.paginate``; per the original note, this is the
        paginated structure (data + pagination) for direct front-end display.

    Raises:
        HTTPException: 400 for an invalid status, page or page_size.
    """
    # Reject non-positive pagination up front instead of forwarding it to paginate
    # (e.g. a negative offset); this mirrors the existing status validation.
    if page < 1:
        raise HTTPException(status_code=400, detail="page 必须 >= 1")
    if page_size < 1:
        raise HTTPException(status_code=400, detail="page_size 必须 >= 1")

    where_parts: list[str] = []
    where_params: list = []

    if status is not None:
        status_int = int(status)
        if status_int not in (0, 1, 2):
            raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
        where_parts.append("status = %s")
        where_params.append(status_int)

    name_str = str(name).strip() if name else ""
    if name_str:
        # Fuzzy match on the name column (varchar(32)).
        where_parts.append("name LIKE %s")
        where_params.append(f"%{name_str}%")

    platform_str = str(platform_type).strip() if platform_type else ""
    if platform_str:
        where_parts.append("platform = %s")
        where_params.append(platform_str)

    return mysql_db.paginate(
        "demand_task",
        page=page,
        page_size=page_size,
        columns="id, execution_id, name, platform, status, create_time, update_time",
        where=" AND ".join(where_parts),
        where_params=tuple(where_params) if where_params else None,
        order_by="id DESC",
    )
def run_server(host: str = "0.0.0.0", port: int = 7000) -> None:
    """Serve ``app`` with uvicorn.

    Args:
        host: Interface to bind (default: all interfaces).
        port: TCP port to listen on (default: 7000).

    uvicorn is imported inside the function, so importing this module does not
    require it.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run_server()
|