| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- """
- demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
- """
- import asyncio
- import sys
- from pathlib import Path
- from typing import Literal, Optional
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- # 添加项目根目录到 Python 路径(与 run.py 保持一致)
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from examples.demand.changwen_prepare import changwen_prepare
- from examples.demand.mysql import mysql_db
- from examples.demand.piaoquan_prepare import piaoquan_prepare
- from examples.demand.run import _create_demand_task, main
# FastAPI application instance; the @app.post / @app.get handlers in this
# file register their routes on it.
app = FastAPI(title="demand web api")
class DemandStartRequest(BaseModel):
    """Request body for starting a demand job."""

    # Name of the cluster to process; also used (truncated to 32 characters)
    # as the task name when the job is created.
    cluster_name: str

    # Which platform's prepare step to run; only these two values validate.
    platform_type: Literal["piaoquan", "changwen"]
# Strong references to in-flight background jobs. The asyncio event loop keeps
# only weak references to tasks, so a task that nothing else references can be
# garbage-collected before it finishes.
_background_jobs: set = set()


def _on_job_done(task):
    """Release a finished background job and log the error it raised, if any."""
    import logging  # local import so this block does not depend on file-level imports

    _background_jobs.discard(task)
    if task.cancelled():
        return
    error = task.exception()  # also marks the exception as retrieved
    if error is not None:
        logging.getLogger(__name__).error(
            "demand background job failed", exc_info=error
        )


@app.post("/demand/start")
async def demand_start(req: DemandStartRequest):
    """Start a demand job and return its identifiers without waiting for it.

    The platform-specific prepare call and the ``_create_demand_task`` call
    run synchronously inside this handler. ``main`` is then scheduled as a
    background asyncio task and the response is returned immediately.

    Args:
        req: Cluster name and platform ("piaoquan" or "changwen").

    Returns:
        A dict with ``ok``, ``message``, ``task_id`` and ``execution_id``.

    Raises:
        HTTPException: 400 when the prepare step yields no execution id;
            500 when ``_create_demand_task`` yields no task id.
    """
    # NOTE(review): the prepare functions and _create_demand_task are called
    # without await, so they presumably block the event loop while they run.
    if req.platform_type == "piaoquan":
        execution_id = piaoquan_prepare(req.cluster_name)
    else:
        execution_id = changwen_prepare(req.cluster_name)

    if not execution_id:
        raise HTTPException(status_code=400, detail="获取 execution_id 失败")

    # The task name is capped at 32 characters.
    task_name = req.cluster_name[:32] if req.cluster_name else None
    task_id = _create_demand_task(
        execution_id=execution_id,
        name=task_name,
        platform=req.platform_type,
    )
    if not task_id:
        raise HTTPException(status_code=500, detail="创建 demand_task 失败")

    async def _job():
        # Per the original author's note, the run writes the task status back
        # to MySQL in a ``finally`` block.
        await main(
            req.cluster_name,
            req.platform_type,
            execution_id=execution_id,
            task_id=task_id,
        )

    # Keep a strong reference until the job completes; the done callback
    # releases it and logs any failure.
    job = asyncio.create_task(_job())
    _background_jobs.add(job)
    job.add_done_callback(_on_job_done)

    return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
@app.get("/demand/task/{task_id}/status")
def demand_task_status(task_id: int, max_log_chars: int = 2000):
    """Return the stored state of one demand task.

    Args:
        task_id: Value matched against the ``id`` column of ``demand_task``.
        max_log_chars: Maximum number of log characters to return. A longer
            log is cut to this length and suffixed with ``...(truncated)``.
            Zero or a negative value disables truncation.

    Returns:
        A dict with ``task_id``, ``execution_id``, ``name``, ``platform``,
        ``status``, ``status_text``, ``final_text`` and ``log``.
        ``final_text`` holds the contents of
        ``output/<execution_id>/result.txt`` (relative to this file) when the
        status is 1 and that file can be read; otherwise it is ``None``.

    Raises:
        HTTPException: 404 when no row matches ``task_id``.
    """
    row = mysql_db.select_one(
        "demand_task",
        columns="id, execution_id, name, platform, status, log",
        where="id = %s",
        where_params=(task_id,),
    )
    if not row:
        raise HTTPException(status_code=404, detail="task not found")

    # A missing or NULL status is reported as 0.
    status = int(row.get("status") or 0)
    status_map = {0: "running", 1: "completed", 2: "failed"}

    log_text = row.get("log") or ""
    # Only a positive limit truncates. A negative limit used to slice
    # characters off the END of the log (``log_text[:-n]``), which is never
    # what a caller wants.
    if (
        max_log_chars
        and max_log_chars > 0
        and isinstance(log_text, str)
        and len(log_text) > max_log_chars
    ):
        log_text = log_text[:max_log_chars] + "...(truncated)"

    execution_id = row.get("execution_id")
    final_text: Optional[str] = None
    if status == 1 and execution_id:
        result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
        try:
            # Read directly rather than checking exists() first, so a file
            # that disappears between check and open cannot cause an error.
            final_text = result_path.read_text(encoding="utf-8")
        except (OSError, ValueError):
            # Best effort: a missing or unreadable file (OSError), undecodable
            # content or an invalid path string (both ValueError) all mean
            # "no final text".
            final_text = None

    return {
        "task_id": task_id,
        "execution_id": execution_id,
        "name": row.get("name"),
        "platform": row.get("platform"),
        "status": status,
        "status_text": status_map.get(status, "unknown"),
        "final_text": final_text,
        "log": log_text,
    }
@app.get("/demand/tasks")
def demand_tasks(
    status: Optional[int] = None,
    name: Optional[str] = None,
    platform_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 20,
):
    """List demand tasks, newest id first, with optional filters.

    Args:
        status: When given, must be 0, 1 or 2; filters on an exact match.
        name: When non-blank, filters on ``name LIKE %<name>%``.
        platform_type: When non-blank, filters on an exact ``platform`` match.
        page: 1-based page number.
        page_size: Rows per page; must be at least 1.

    Returns:
        Whatever ``mysql_db.paginate`` returns for the query (per the original
        comment, a structure holding the rows plus pagination info).

    Raises:
        HTTPException: 400 when ``status`` is not 0/1/2, or when ``page`` or
            ``page_size`` is below 1.
    """
    # Reject impossible pagination up front instead of forwarding it to the
    # database layer, where it could become a negative OFFSET/LIMIT.
    if page < 1 or page_size < 1:
        raise HTTPException(status_code=400, detail="page 和 page_size 必须为正整数")

    # Each entry pairs a WHERE fragment with the parameter bound to it.
    filters: list = []

    if status is not None:
        status_int = int(status)
        if status_int not in (0, 1, 2):
            raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
        filters.append(("status = %s", status_int))

    name_str = str(name).strip() if name else ""
    if name_str:
        # Substring match on the name column.
        # NOTE(review): '%' and '_' inside the search text still act as LIKE
        # wildcards; escape them if literal matching is required.
        filters.append(("name LIKE %s", f"%{name_str}%"))

    platform_str = str(platform_type).strip() if platform_type else ""
    if platform_str:
        filters.append(("platform = %s", platform_str))

    # With no filters this yields an empty WHERE string and no parameters,
    # exactly as before.
    where = " AND ".join(clause for clause, _ in filters)
    params = tuple(value for _, value in filters) or None

    return mysql_db.paginate(
        "demand_task",
        page=page,
        page_size=page_size,
        columns="id, execution_id, name, platform, status, create_time, update_time",
        where=where,
        where_params=params,
        order_by="id DESC",
    )
|