"""
Demand web API (asynchronous tasks).

Starting a task returns a ``task_id`` immediately; a separate endpoint is
polled to query the task's status.
"""

import asyncio
import sys
from pathlib import Path
from typing import Literal, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Add the project root to the Python path (kept consistent with run.py).
# This must run before the ``examples.*`` imports below.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from examples.demand.changwen_prepare import changwen_prepare
from examples.demand.mysql import mysql_db
from examples.demand.piaoquan_prepare import piaoquan_prepare
from examples.demand.run import _create_demand_task, main

app = FastAPI(title="demand web api")

# Strong references to in-flight background jobs. The event loop itself only
# keeps weak references to tasks, so a task that nobody else references may be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


class DemandStartRequest(BaseModel):
    # Cluster name; its first 32 characters are also used as the task name.
    cluster_name: str
    # Which prepare function computes the execution_id.
    platform_type: Literal["piaoquan", "changwen"]


@app.post("/demand/start")
async def demand_start(req: DemandStartRequest):
    """Create a demand task and start it in the background.

    Returns a dict with ``ok``, ``message``, ``task_id`` and ``execution_id``.

    Raises:
        HTTPException: 400 if no execution_id could be obtained, 500 if the
            demand_task row could not be created.
    """
    # NOTE: execution_id is computed synchronously here (prepare stage); only
    # the subsequent run is pushed to the background.
    # NOTE(review): the prepare call runs on the event loop and blocks it for
    # its whole duration — consider offloading if it is slow.
    if req.platform_type == "piaoquan":
        execution_id = piaoquan_prepare(req.cluster_name)
    else:
        execution_id = changwen_prepare(req.cluster_name)

    if not execution_id:
        raise HTTPException(status_code=400, detail="获取 execution_id 失败")

    # The name is cut to 32 characters; an empty cluster name yields None.
    task_name = req.cluster_name[:32] if req.cluster_name else None
    task_id = _create_demand_task(
        execution_id=execution_id,
        name=task_name,
        platform=req.platform_type,
    )
    if not task_id:
        raise HTTPException(status_code=500, detail="创建 demand_task 失败")

    async def _job():
        # Per the original author's note, run_once writes the task status back
        # to MySQL in a ``finally`` block (not visible in this file).
        await main(
            req.cluster_name,
            req.platform_type,
            execution_id=execution_id,
            task_id=task_id,
        )

    # Keep a strong reference until the job finishes so it cannot be
    # garbage-collected mid-run; drop it again once done.
    job_task = asyncio.create_task(_job())
    _background_tasks.add(job_task)
    job_task.add_done_callback(_background_tasks.discard)

    return {
        "ok": True,
        "message": "调用成功",
        "task_id": task_id,
        "execution_id": execution_id,
    }


@app.get("/demand/task/{task_id}/status")
def demand_task_status(task_id: int, max_log_chars: int = 2000):
    """Return the status of one demand task.

    Args:
        task_id: Primary key of the ``demand_task`` row.
        max_log_chars: Maximum number of log characters to return; longer logs
            are cut to this prefix and suffixed with ``...(truncated)``.
            Zero or a negative value disables truncation.

    Returns:
        A dict with the task's identifiers, numeric ``status``, its
        ``status_text``, the (possibly truncated) ``log`` and, for completed
        tasks whose result file can be read, ``final_text``.

    Raises:
        HTTPException: 404 if no row with this id exists.
    """
    row = mysql_db.select_one(
        "demand_task",
        columns="id, execution_id, name, platform, status, log",
        where="id = %s",
        where_params=(task_id,),
    )
    if not row:
        raise HTTPException(status_code=404, detail="task not found")

    status = int(row.get("status") or 0)
    status_map = {0: "running", 1: "completed", 2: "failed"}

    log_text = row.get("log") or ""
    # Only a positive limit truncates: a negative limit used as a slice bound
    # would chop characters off the END of the log instead.
    if (
        max_log_chars > 0
        and isinstance(log_text, str)
        and len(log_text) > max_log_chars
    ):
        log_text = log_text[:max_log_chars] + "...(truncated)"

    execution_id = row.get("execution_id")
    final_text: Optional[str] = None
    if status == 1 and execution_id:
        result_path = (
            Path(__file__).parent / "output" / str(execution_id) / "result.txt"
        )
        try:
            final_text = result_path.read_text(encoding="utf-8")
        except (OSError, ValueError):
            # Best effort: a missing or unreadable file, or one that is not
            # valid UTF-8, simply means no final text is returned.
            final_text = None

    return {
        "task_id": task_id,
        "execution_id": execution_id,
        "name": row.get("name"),
        "platform": row.get("platform"),
        "status": status,
        "status_text": status_map.get(status, "unknown"),
        "final_text": final_text,
        "log": log_text,
    }


@app.get("/demand/tasks")
def demand_tasks(
    status: Optional[int] = None,
    name: Optional[str] = None,
    platform_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
):
    """List demand tasks, newest first, with optional filters.

    Args:
        status: If given, must be 0, 1 or 2; filters on the status column.
        name: If non-blank, substring match (``LIKE %name%``) on the name.
        platform_type: If non-blank, exact match on the platform column.
        page: Page number passed through to ``mysql_db.paginate``.
        page_size: Page size passed through to ``mysql_db.paginate``.

    Returns:
        Whatever ``mysql_db.paginate`` returns.

    Raises:
        HTTPException: 400 if ``status`` is not one of 0/1/2.
    """
    where_parts: list[str] = []
    where_params: list = []

    if status is not None:
        if status not in (0, 1, 2):
            raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
        where_parts.append("status = %s")
        where_params.append(status)

    name_str = name.strip() if name else ""
    if name_str:
        # Fuzzy match on the demand name column (varchar(32)).
        where_parts.append("name LIKE %s")
        where_params.append(f"%{name_str}%")

    platform_str = platform_type.strip() if platform_type else ""
    if platform_str:
        where_parts.append("platform = %s")
        where_params.append(platform_str)

    # With no filters the WHERE clause is an empty string and params are None.
    where = " AND ".join(where_parts)
    params = tuple(where_params) if where_params else None

    # Per the original author's note, the result is a paginated structure
    # (data + pagination) that the frontend can display directly.
    return mysql_db.paginate(
        "demand_task",
        page=page,
        page_size=page_size,
        columns="id, execution_id, name, platform, status, create_time, update_time",
        where=where,
        where_params=params,
        order_by="id DESC",
    )