web_api.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """
  2. demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
  3. """
  4. import asyncio
  5. import sys
  6. from pathlib import Path
  7. from typing import Literal, Optional
  8. from fastapi import FastAPI, HTTPException
  9. from pydantic import BaseModel
  10. # 添加项目根目录到 Python 路径(与 run.py 保持一致)
  11. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  12. from examples.demand.changwen_prepare import changwen_prepare
  13. from examples.demand.mysql import mysql_db
  14. from examples.demand.piaoquan_prepare import piaoquan_prepare
  15. from examples.demand.run import _create_demand_task, main
  16. app = FastAPI(title="demand web api")
  17. class DemandStartRequest(BaseModel):
  18. cluster_name: str
  19. platform_type: Literal["piaoquan", "changwen"]
  20. @app.post("/demand/start")
  21. async def demand_start(req: DemandStartRequest):
  22. # 注意:这里会同步计算 execution_id(prepare 阶段),随后 run_once 放到后台异步执行。
  23. if req.platform_type == "piaoquan":
  24. execution_id = piaoquan_prepare(req.cluster_name)
  25. else:
  26. execution_id = changwen_prepare(req.cluster_name)
  27. if not execution_id:
  28. raise HTTPException(status_code=400, detail="获取 execution_id 失败")
  29. task_name = req.cluster_name[:32] if req.cluster_name else None
  30. task_id = _create_demand_task(
  31. execution_id=execution_id,
  32. name=task_name,
  33. platform=req.platform_type,
  34. )
  35. if not task_id:
  36. raise HTTPException(status_code=500, detail="创建 demand_task 失败")
  37. async def _job():
  38. # run_once 内部会在 finally 里把 task 状态写回 MySQL。
  39. await main(
  40. req.cluster_name,
  41. req.platform_type,
  42. execution_id=execution_id,
  43. task_id=task_id,
  44. )
  45. asyncio.create_task(_job())
  46. return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
  47. @app.get("/demand/task/{task_id}/status")
  48. def demand_task_status(task_id: int, max_log_chars: int = 2000):
  49. row = mysql_db.select_one(
  50. "demand_task",
  51. columns="id, execution_id, name, platform, status, log",
  52. where="id = %s",
  53. where_params=(task_id,),
  54. )
  55. if not row:
  56. raise HTTPException(status_code=404, detail="task not found")
  57. status = int(row.get("status") or 0)
  58. status_map = {0: "running", 1: "completed", 2: "failed"}
  59. log_text = row.get("log") or ""
  60. if max_log_chars and isinstance(log_text, str) and len(log_text) > max_log_chars:
  61. log_text = log_text[:max_log_chars] + "...(truncated)"
  62. execution_id = row.get("execution_id")
  63. final_text: Optional[str] = None
  64. if status == 1 and execution_id:
  65. try:
  66. result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
  67. if result_path.exists():
  68. with open(result_path, "r", encoding="utf-8") as f:
  69. final_text = f.read()
  70. except Exception:
  71. final_text = None
  72. return {
  73. "task_id": task_id,
  74. "execution_id": execution_id,
  75. "name": row.get("name"),
  76. "platform": row.get("platform"),
  77. "status": status,
  78. "status_text": status_map.get(status, "unknown"),
  79. "final_text": final_text,
  80. "log": log_text,
  81. }
  82. @app.get("/demand/tasks")
  83. def demand_tasks(
  84. status: Optional[int] = None,
  85. name: Optional[str] = None,
  86. platform_type: Optional[str] = None,
  87. page: int = 1,
  88. page_size: int = 20,
  89. ):
  90. where_parts: list[str] = []
  91. where_params: list = []
  92. if status is not None:
  93. status_int = int(status)
  94. if status_int not in (0, 1, 2):
  95. raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
  96. where_parts.append("status = %s")
  97. where_params.append(status_int)
  98. if name:
  99. name_str = str(name).strip()
  100. if name_str:
  101. # 支持模糊匹配:根据需求名称字段(varchar(32))
  102. where_parts.append("name LIKE %s")
  103. where_params.append(f"%{name_str}%")
  104. if platform_type:
  105. platform_str = str(platform_type).strip()
  106. if platform_str:
  107. where_parts.append("platform = %s")
  108. where_params.append(platform_str)
  109. where = " AND ".join(where_parts)
  110. params = tuple(where_params) if where_params else None
  111. data = mysql_db.paginate(
  112. "demand_task",
  113. page=page,
  114. page_size=page_size,
  115. columns="id, execution_id, name, platform, status, create_time, update_time",
  116. where=where,
  117. where_params=params,
  118. order_by="id DESC",
  119. )
  120. # 返回分页结构(data + pagination),便于前端直接展示
  121. return data