task_manager.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. from __future__ import annotations
  2. import asyncio
  3. import logging
  4. import uuid
  5. from collections import defaultdict
  6. from collections.abc import Callable
  7. from typing import Any, Literal
  8. import httpx
  9. from utils.env_parse import env_float, env_str
  10. from gateway.core.executor.errors import ExecutorError, TaskNotFoundError
  11. from gateway.core.executor.models import TERMINAL_STATUSES, RunMode, TaskRecord
  12. from gateway.core.lifecycle import TraceManager, WorkspaceManager
  13. from gateway.core.lifecycle.errors import LifecycleError
  14. logger = logging.getLogger(__name__)
# --- Execution context (paths + in-memory logs) ---
  16. class ExecutionContext:
  17. """Workspace 路径解析与执行日志(内存)。"""
  18. def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
  19. self._wm = workspace_manager
  20. self._tm = trace_manager
  21. self._logs: dict[str, list[dict[str, Any]]] = defaultdict(list)
  22. async def create_context(self, task_id: str, trace_id: str) -> dict[str, Any]:
  23. try:
  24. workspace_id = self._tm.get_workspace_id(trace_id)
  25. except LifecycleError as e:
  26. logger.warning("ExecutionContext.create_context: %s", e)
  27. return {
  28. "task_id": task_id,
  29. "trace_id": trace_id,
  30. "workspace_id": None,
  31. "workspace_path": None,
  32. "note": str(e),
  33. }
  34. try:
  35. path = await self._wm.get_workspace_path(workspace_id)
  36. except Exception as e:
  37. logger.warning("ExecutionContext.get_workspace_path failed: %s", e)
  38. path = None
  39. return {
  40. "task_id": task_id,
  41. "trace_id": trace_id,
  42. "workspace_id": workspace_id,
  43. "workspace_path": path,
  44. }
  45. async def get_workspace_path(self, trace_id: str) -> str | None:
  46. try:
  47. workspace_id = self._tm.get_workspace_id(trace_id)
  48. return await self._wm.get_workspace_path(workspace_id)
  49. except Exception as e:
  50. logger.warning("ExecutionContext.get_workspace_path trace_id=%s err=%s", trace_id, e)
  51. return None
  52. def log_execution(self, task_id: str, log_entry: dict[str, Any]) -> None:
  53. self._logs[task_id].append(dict(log_entry))
  54. def get_logs(self, task_id: str) -> list[dict[str, Any]]:
  55. return list(self._logs.get(task_id, []))
# --- In-memory task table and per-trace mutual exclusion ---
  57. class _InMemoryTaskStore:
  58. def __init__(self) -> None:
  59. self._by_id: dict[str, TaskRecord] = {}
  60. def put(self, record: TaskRecord) -> None:
  61. self._by_id[record.task_id] = record
  62. def get(self, task_id: str) -> TaskRecord | None:
  63. return self._by_id.get(task_id)
  64. def require(self, task_id: str) -> TaskRecord:
  65. rec = self._by_id.get(task_id)
  66. if rec is None:
  67. raise TaskNotFoundError(task_id)
  68. return rec
  69. def list_as_dicts(
  70. self,
  71. *,
  72. trace_id: str | None = None,
  73. status: str | None = None,
  74. ) -> list[dict[str, Any]]:
  75. out: list[dict[str, Any]] = []
  76. for rec in self._by_id.values():
  77. if trace_id is not None and rec.trace_id != trace_id:
  78. continue
  79. if status is not None and rec.status != status:
  80. continue
  81. out.append(rec.to_dict())
  82. out.sort(key=lambda d: d.get("created_at") or "", reverse=True)
  83. return out
  84. def contains(self, task_id: str) -> bool:
  85. return task_id in self._by_id
  86. class _TraceSerialLocks:
  87. def __init__(self) -> None:
  88. self._locks: dict[str, asyncio.Lock] = {}
  89. def lock_for(self, trace_id: str) -> asyncio.Lock:
  90. if trace_id not in self._locks:
  91. self._locks[trace_id] = asyncio.Lock()
  92. return self._locks[trace_id]
# --- Agent HTTP client and gateway_exec ---
  94. class AgentTraceHttpClient:
  95. """对 Agent ``/api/traces`` 的最小封装。"""
  96. def __init__(self, *, base_url: str, timeout: float) -> None:
  97. self._base = base_url.rstrip("/")
  98. self._timeout = timeout
  99. async def post_run(
  100. self,
  101. trace_id: str,
  102. *,
  103. messages: list[dict[str, Any]],
  104. gateway_exec: dict[str, Any] | None = None,
  105. ) -> tuple[int, Any]:
  106. body: dict[str, Any] = {"messages": messages}
  107. if gateway_exec:
  108. body["gateway_exec"] = gateway_exec
  109. async with httpx.AsyncClient(timeout=self._timeout) as client:
  110. r = await client.post(f"{self._base}/api/traces/{trace_id}/run", json=body)
  111. try:
  112. payload: Any = r.json()
  113. except Exception:
  114. payload = r.text
  115. return r.status_code, payload
  116. async def get_trace_status(self, trace_id: str) -> str | None:
  117. async with httpx.AsyncClient(timeout=self._timeout) as client:
  118. r = await client.get(f"{self._base}/api/traces/{trace_id}")
  119. if r.status_code != 200:
  120. return None
  121. try:
  122. data = r.json()
  123. except Exception:
  124. return None
  125. trace = data.get("trace")
  126. if isinstance(trace, dict):
  127. st = trace.get("status")
  128. return str(st) if st is not None else None
  129. return None
  130. async def post_stop(self, trace_id: str) -> bool:
  131. async with httpx.AsyncClient(timeout=self._timeout) as client:
  132. r = await client.post(f"{self._base}/api/traces/{trace_id}/stop")
  133. return r.status_code < 400
  134. class _GatewayExecResolver:
  135. def __init__(self, workspace_manager: WorkspaceManager, trace_manager: TraceManager) -> None:
  136. self._wm = workspace_manager
  137. self._trace_mgr = trace_manager
  138. def resolve(self, trace_id: str) -> dict[str, Any] | None:
  139. try:
  140. wid = self._trace_mgr.get_workspace_id(trace_id)
  141. except LifecycleError:
  142. return None
  143. cid = self._wm.get_workspace_container_id(wid)
  144. if not cid:
  145. return None
  146. return {
  147. "docker_container": cid,
  148. "container_user": "agent",
  149. "container_workdir": "/home/agent/workspace",
  150. }
# --- Single-task execution pipeline ---
  152. class _TaskExecutionPipeline:
  153. def __init__(
  154. self,
  155. *,
  156. store: _InMemoryTaskStore,
  157. trace_client: AgentTraceHttpClient,
  158. gateway_exec: _GatewayExecResolver,
  159. execution_context: ExecutionContext,
  160. poll_interval: float,
  161. poll_max_seconds: float,
  162. is_cancelled: Callable[[str], bool],
  163. ) -> None:
  164. self._store = store
  165. self._http = trace_client
  166. self._gateway_exec = gateway_exec
  167. self._ctx = execution_context
  168. self._poll_interval = poll_interval
  169. self._poll_max_seconds = poll_max_seconds
  170. self._is_cancelled = is_cancelled
  171. async def run_after_lock_acquired(self, task_id: str) -> None:
  172. rec = self._store.get(task_id)
  173. if rec is None:
  174. return
  175. if self._is_cancelled(task_id):
  176. rec.touch(status="cancelled", error_message="cancelled_before_run")
  177. return
  178. rec.touch(status="running")
  179. self._ctx.log_execution(task_id, {"event": "run_start"})
  180. gateway_exec = self._gateway_exec.resolve(rec.trace_id)
  181. messages = [{"role": "user", "content": rec.task_description}]
  182. try:
  183. status_code, payload = await self._http.post_run(
  184. rec.trace_id,
  185. messages=messages,
  186. gateway_exec=gateway_exec,
  187. )
  188. except Exception as e:
  189. logger.exception("executor post_run failed task_id=%s", task_id)
  190. rec.touch(status="failed", error_message=str(e))
  191. self._ctx.log_execution(task_id, {"event": "http_error", "error": str(e)})
  192. return
  193. if status_code == 409:
  194. rec.touch(status="failed", error_message="trace_already_running")
  195. return
  196. if status_code >= 400:
  197. detail = payload if isinstance(payload, str) else str(payload)
  198. rec.touch(status="failed", error_message=f"agent_http_{status_code}: {detail[:500]}")
  199. return
  200. self._ctx.log_execution(task_id, {"event": "run_accepted", "status_code": status_code})
  201. poll_kind, terminal, poll_err = await self._poll_until_terminal(task_id, rec)
  202. if poll_kind == "cancelled":
  203. return
  204. if poll_err == "poll_timeout":
  205. rec.touch(status="failed", error_message="poll_timeout", trace_terminal_status=terminal)
  206. elif terminal == "failed":
  207. rec.touch(status="failed", error_message="trace_failed", trace_terminal_status=terminal)
  208. elif terminal == "stopped":
  209. rec.touch(status="cancelled", error_message="trace_stopped", trace_terminal_status=terminal)
  210. elif terminal == "completed":
  211. rec.touch(status="completed", trace_terminal_status=terminal)
  212. elif terminal:
  213. rec.touch(status="completed", trace_terminal_status=terminal)
  214. else:
  215. rec.touch(status="failed", error_message="status_unknown")
  216. self._ctx.log_execution(
  217. task_id,
  218. {"event": "finished", "terminal": terminal, "err": poll_err},
  219. )
  220. async def _poll_until_terminal(
  221. self, task_id: str, rec: TaskRecord
  222. ) -> tuple[Literal["ok", "cancelled"], str | None, str | None]:
  223. elapsed = 0.0
  224. terminal: str | None = None
  225. poll_err: str | None = None
  226. while elapsed <= self._poll_max_seconds:
  227. if self._is_cancelled(task_id):
  228. await self._http.post_stop(rec.trace_id)
  229. rec.touch(
  230. status="cancelled",
  231. error_message="cancelled_during_run",
  232. trace_terminal_status="stopped",
  233. )
  234. self._ctx.log_execution(task_id, {"event": "cancelled_mid_poll"})
  235. return "cancelled", None, None
  236. st = await self._http.get_trace_status(rec.trace_id)
  237. if st and st in TERMINAL_STATUSES:
  238. terminal = st
  239. break
  240. await asyncio.sleep(self._poll_interval)
  241. elapsed += self._poll_interval
  242. else:
  243. poll_err = "poll_timeout"
  244. return "ok", terminal, poll_err
# --- Orchestration entry point ---
  246. class TaskManager:
  247. """
  248. 校验 Trace、内存任务表、按 trace 串行执行、同步等待与取消。
  249. 内部类:存储、锁、HTTP、管道均在本模块实现。
  250. """
  251. def __init__(
  252. self,
  253. *,
  254. workspace_manager: WorkspaceManager,
  255. trace_manager: TraceManager,
  256. agent_api_base_url: str,
  257. http_timeout: float,
  258. poll_interval: float = 2.0,
  259. poll_max_seconds: float = 86400.0,
  260. ) -> None:
  261. self._trace_mgr = trace_manager
  262. self._store = _InMemoryTaskStore()
  263. self._trace_locks = _TraceSerialLocks()
  264. self._http = AgentTraceHttpClient(base_url=agent_api_base_url, timeout=http_timeout)
  265. self._gateway_exec = _GatewayExecResolver(workspace_manager, trace_manager)
  266. self._ctx = ExecutionContext(workspace_manager, trace_manager)
  267. self._pipeline = _TaskExecutionPipeline(
  268. store=self._store,
  269. trace_client=self._http,
  270. gateway_exec=self._gateway_exec,
  271. execution_context=self._ctx,
  272. poll_interval=poll_interval,
  273. poll_max_seconds=poll_max_seconds,
  274. is_cancelled=self._is_cancelled,
  275. )
  276. self._cancelled: set[str] = set()
  277. self._done_events: dict[str, asyncio.Event] = {}
  278. def _is_cancelled(self, task_id: str) -> bool:
  279. return task_id in self._cancelled
  280. @classmethod
  281. def from_env(
  282. cls,
  283. workspace_manager: WorkspaceManager,
  284. trace_manager: TraceManager,
  285. ) -> TaskManager:
  286. return cls(
  287. workspace_manager=workspace_manager,
  288. trace_manager=trace_manager,
  289. agent_api_base_url=env_str("GATEWAY_AGENT_API_BASE_URL", "http://127.0.0.1:8000"),
  290. http_timeout=env_float("GATEWAY_AGENT_API_TIMEOUT", 60.0),
  291. poll_interval=env_float("GATEWAY_EXECUTOR_POLL_INTERVAL", 2.0),
  292. poll_max_seconds=env_float("GATEWAY_EXECUTOR_POLL_MAX_SECONDS", 86400.0),
  293. )
  294. async def submit_task(
  295. self,
  296. trace_id: str,
  297. task_description: str,
  298. mode: RunMode = "async",
  299. metadata: dict[str, Any] | None = None,
  300. ) -> str:
  301. await self._trace_mgr.get_trace(trace_id)
  302. task_id = f"gtask-{uuid.uuid4()}"
  303. rec = TaskRecord(
  304. task_id=task_id,
  305. trace_id=trace_id,
  306. task_description=task_description,
  307. mode=mode,
  308. metadata=dict(metadata or {}),
  309. )
  310. self._store.put(rec)
  311. done_ev = asyncio.Event()
  312. self._done_events[task_id] = done_ev
  313. self._ctx.log_execution(
  314. task_id,
  315. {"event": "submitted", "trace_id": trace_id, "mode": mode},
  316. )
  317. asyncio.create_task(self._run_task_pipeline(task_id), name=f"executor:{task_id}")
  318. if mode == "sync":
  319. await done_ev.wait()
  320. return task_id
  321. async def _run_task_pipeline(self, task_id: str) -> None:
  322. try:
  323. rec = self._store.get(task_id)
  324. if not rec:
  325. return
  326. if task_id in self._cancelled:
  327. rec.touch(status="cancelled", error_message="cancelled_before_start")
  328. return
  329. async with self._trace_locks.lock_for(rec.trace_id):
  330. if task_id in self._cancelled:
  331. rec.touch(status="cancelled", error_message="cancelled_before_run")
  332. return
  333. await self._pipeline.run_after_lock_acquired(task_id)
  334. finally:
  335. self._finish(task_id)
  336. def _finish(self, task_id: str) -> None:
  337. ev = self._done_events.pop(task_id, None)
  338. if ev is not None:
  339. ev.set()
  340. def get_task(self, task_id: str) -> dict[str, Any]:
  341. return self._store.require(task_id).to_dict()
  342. def list_tasks(
  343. self,
  344. trace_id: str | None = None,
  345. status: str | None = None,
  346. ) -> list[dict[str, Any]]:
  347. return self._store.list_as_dicts(trace_id=trace_id, status=status)
  348. async def cancel_task(self, task_id: str) -> None:
  349. rec = self._store.get(task_id)
  350. if not rec:
  351. raise TaskNotFoundError(task_id)
  352. if rec.status in ("completed", "failed", "cancelled"):
  353. raise ExecutorError(f"任务已结束,无法取消: {rec.status}")
  354. self._cancelled.add(task_id)
  355. if rec.status == "running":
  356. try:
  357. await self._http.post_stop(rec.trace_id)
  358. except Exception as e:
  359. logger.warning("TaskManager cancel stop trace failed: %s", e)
  360. self._ctx.log_execution(task_id, {"event": "cancel_requested"})
  361. def get_execution_logs(self, task_id: str) -> list[dict[str, Any]]:
  362. if not self._store.contains(task_id):
  363. raise TaskNotFoundError(task_id)
  364. return self._ctx.get_logs(task_id)