api.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from __future__ import annotations
  2. import logging
  3. from typing import Any
  4. from fastapi import APIRouter, HTTPException, Query
  5. from pydantic import BaseModel, Field
  6. from gateway.core.executor.errors import ExecutorError, TaskNotFoundError
  7. from gateway.core.executor.models import RunMode
  8. from gateway.core.executor.task_manager import TaskManager
  9. from gateway.core.lifecycle.errors import LifecycleError
  10. logger = logging.getLogger(__name__)
  11. class SubmitTaskRequest(BaseModel):
  12. trace_id: str = Field(..., description="Agent Trace ID")
  13. task_description: str = Field(..., description="用户任务描述,将作为一条 user 消息续跑")
  14. mode: RunMode = Field("async", description="async:立即返回 task_id;sync:阻塞至 Trace 终态")
  15. metadata: dict[str, Any] = Field(default_factory=dict)
  16. class SubmitTaskResponse(BaseModel):
  17. task_id: str
  18. def build_executor_router(task_manager: TaskManager) -> APIRouter:
  19. router = APIRouter(prefix="/gateway/executor", tags=["gateway-executor"])
  20. @router.post("/tasks", response_model=SubmitTaskResponse)
  21. async def submit_task(req: SubmitTaskRequest) -> SubmitTaskResponse:
  22. try:
  23. task_id = await task_manager.submit_task(
  24. req.trace_id,
  25. req.task_description,
  26. mode=req.mode,
  27. metadata=req.metadata,
  28. )
  29. except TaskNotFoundError as e:
  30. raise HTTPException(status_code=404, detail=str(e)) from e
  31. except LifecycleError as e:
  32. raise HTTPException(status_code=404, detail=str(e)) from e
  33. except ExecutorError as e:
  34. raise HTTPException(status_code=400, detail=str(e)) from e
  35. except Exception as e:
  36. logger.exception("executor submit_task")
  37. raise HTTPException(status_code=502, detail=str(e)) from e
  38. return SubmitTaskResponse(task_id=task_id)
  39. @router.get("/tasks/{task_id}")
  40. async def get_task(task_id: str) -> dict[str, Any]:
  41. try:
  42. return task_manager.get_task(task_id)
  43. except TaskNotFoundError as e:
  44. raise HTTPException(status_code=404, detail=str(e)) from e
  45. @router.get("/tasks")
  46. async def list_tasks(
  47. trace_id: str | None = Query(None),
  48. status: str | None = Query(None),
  49. ) -> dict[str, Any]:
  50. items = task_manager.list_tasks(trace_id=trace_id, status=status)
  51. return {"tasks": items, "count": len(items)}
  52. @router.get("/tasks/{task_id}/logs")
  53. async def task_logs(task_id: str) -> dict[str, Any]:
  54. try:
  55. logs = task_manager.get_execution_logs(task_id)
  56. except TaskNotFoundError as e:
  57. raise HTTPException(status_code=404, detail=str(e)) from e
  58. return {"task_id": task_id, "logs": logs}
  59. @router.post("/tasks/{task_id}/cancel")
  60. async def cancel_task(task_id: str) -> dict[str, str]:
  61. try:
  62. await task_manager.cancel_task(task_id)
  63. except TaskNotFoundError as e:
  64. raise HTTPException(status_code=404, detail=str(e)) from e
  65. except ExecutorError as e:
  66. raise HTTPException(status_code=400, detail=str(e)) from e
  67. return {"task_id": task_id, "status": "cancel_requested"}
  68. return router