|
|
@@ -0,0 +1,149 @@
|
|
|
+"""
|
|
|
+demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
|
|
|
+"""
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+from typing import Literal, Optional
|
|
|
+
|
|
|
+from fastapi import FastAPI, HTTPException
|
|
|
+from pydantic import BaseModel
|
|
|
+
|
|
|
+# 添加项目根目录到 Python 路径(与 run.py 保持一致)
|
|
|
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
+
|
|
|
+from examples.demand.changwen_prepare import changwen_prepare
|
|
|
+from examples.demand.mysql import mysql_db
|
|
|
+from examples.demand.piaoquan_prepare import piaoquan_prepare
|
|
|
+from examples.demand.run import _create_demand_task, main
|
|
|
+
|
|
|
+app = FastAPI(title="demand web api")
|
|
|
+
|
|
|
+
|
|
|
+class DemandStartRequest(BaseModel):
|
|
|
+ cluster_name: str
|
|
|
+ platform_type: Literal["piaoquan", "changwen"]
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/demand/start")
|
|
|
+async def demand_start(req: DemandStartRequest):
|
|
|
+ # 注意:这里会同步计算 execution_id(prepare 阶段),随后 run_once 放到后台异步执行。
|
|
|
+ if req.platform_type == "piaoquan":
|
|
|
+ execution_id = piaoquan_prepare(req.cluster_name)
|
|
|
+ else:
|
|
|
+ execution_id = changwen_prepare(req.cluster_name)
|
|
|
+
|
|
|
+ if not execution_id:
|
|
|
+ raise HTTPException(status_code=400, detail="获取 execution_id 失败")
|
|
|
+
|
|
|
+ task_name = req.cluster_name[:32] if req.cluster_name else None
|
|
|
+ task_id = _create_demand_task(
|
|
|
+ execution_id=execution_id,
|
|
|
+ name=task_name,
|
|
|
+ platform=req.platform_type,
|
|
|
+ )
|
|
|
+ if not task_id:
|
|
|
+ raise HTTPException(status_code=500, detail="创建 demand_task 失败")
|
|
|
+
|
|
|
+ async def _job():
|
|
|
+ # run_once 内部会在 finally 里把 task 状态写回 MySQL。
|
|
|
+ await main(
|
|
|
+ req.cluster_name,
|
|
|
+ req.platform_type,
|
|
|
+ execution_id=execution_id,
|
|
|
+ task_id=task_id,
|
|
|
+ )
|
|
|
+
|
|
|
+ asyncio.create_task(_job())
|
|
|
+ return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/demand/task/{task_id}/status")
|
|
|
+def demand_task_status(task_id: int, max_log_chars: int = 2000):
|
|
|
+ row = mysql_db.select_one(
|
|
|
+ "demand_task",
|
|
|
+ columns="id, execution_id, name, platform, status, log",
|
|
|
+ where="id = %s",
|
|
|
+ where_params=(task_id,),
|
|
|
+ )
|
|
|
+ if not row:
|
|
|
+ raise HTTPException(status_code=404, detail="task not found")
|
|
|
+
|
|
|
+ status = int(row.get("status") or 0)
|
|
|
+ status_map = {0: "running", 1: "completed", 2: "failed"}
|
|
|
+
|
|
|
+ log_text = row.get("log") or ""
|
|
|
+ if max_log_chars and isinstance(log_text, str) and len(log_text) > max_log_chars:
|
|
|
+ log_text = log_text[:max_log_chars] + "...(truncated)"
|
|
|
+
|
|
|
+ execution_id = row.get("execution_id")
|
|
|
+ final_text: Optional[str] = None
|
|
|
+ if status == 1 and execution_id:
|
|
|
+ try:
|
|
|
+ result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
|
|
|
+ if result_path.exists():
|
|
|
+ with open(result_path, "r", encoding="utf-8") as f:
|
|
|
+ final_text = f.read()
|
|
|
+ except Exception:
|
|
|
+ final_text = None
|
|
|
+
|
|
|
+ return {
|
|
|
+ "task_id": task_id,
|
|
|
+ "execution_id": execution_id,
|
|
|
+ "name": row.get("name"),
|
|
|
+ "platform": row.get("platform"),
|
|
|
+ "status": status,
|
|
|
+ "status_text": status_map.get(status, "unknown"),
|
|
|
+ "final_text": final_text,
|
|
|
+ "log": log_text,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/demand/tasks")
|
|
|
+def demand_tasks(
|
|
|
+ status: Optional[int] = None,
|
|
|
+ name: Optional[str] = None,
|
|
|
+ platform_type: Optional[str] = None,
|
|
|
+ page: int = 1,
|
|
|
+ page_size: int = 20,
|
|
|
+):
|
|
|
+ where_parts: list[str] = []
|
|
|
+ where_params: list = []
|
|
|
+
|
|
|
+ if status is not None:
|
|
|
+ status_int = int(status)
|
|
|
+ if status_int not in (0, 1, 2):
|
|
|
+ raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
|
|
|
+ where_parts.append("status = %s")
|
|
|
+ where_params.append(status_int)
|
|
|
+
|
|
|
+ if name:
|
|
|
+ name_str = str(name).strip()
|
|
|
+ if name_str:
|
|
|
+ # 支持模糊匹配:根据需求名称字段(varchar(32))
|
|
|
+ where_parts.append("name LIKE %s")
|
|
|
+ where_params.append(f"%{name_str}%")
|
|
|
+
|
|
|
+ if platform_type:
|
|
|
+ platform_str = str(platform_type).strip()
|
|
|
+ if platform_str:
|
|
|
+ where_parts.append("platform = %s")
|
|
|
+ where_params.append(platform_str)
|
|
|
+
|
|
|
+ where = " AND ".join(where_parts)
|
|
|
+ params = tuple(where_params) if where_params else None
|
|
|
+
|
|
|
+ data = mysql_db.paginate(
|
|
|
+ "demand_task",
|
|
|
+ page=page,
|
|
|
+ page_size=page_size,
|
|
|
+ columns="id, execution_id, name, platform, status, create_time, update_time",
|
|
|
+ where=where,
|
|
|
+ where_params=params,
|
|
|
+ order_by="id DESC",
|
|
|
+ )
|
|
|
+
|
|
|
+ # 返回分页结构(data + pagination),便于前端直接展示
|
|
|
+ return data
|
|
|
+
|