| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636 |
- """Docker 容器管理"""
- from __future__ import annotations
- import base64
- import json
- import logging
- import os
- import socket
- import threading
- import time
- import uuid
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any, Callable
- import docker
- import docker.errors
- import docker.types
- import httpx
- from tool_agent.config import settings
- from tool_agent.models import ContainerInfo, ContainerStatus, ToolMeta
- logger = logging.getLogger(__name__)
- class ContainerStore:
- """容器状态 JSON 持久化层 — 替代 sandbox_repository 的 MySQL"""
- def __init__(self, path: Path | None = None) -> None:
- self._path = path or (settings.data_dir / "containers.json")
- self._lock = threading.Lock()
- def _load(self) -> list[dict[str, Any]]:
- if not self._path.exists():
- return []
- data = json.loads(self._path.read_text(encoding="utf-8"))
- return data.get("containers", [])
- def _save(self, containers: list[dict[str, Any]]) -> None:
- self._path.parent.mkdir(parents=True, exist_ok=True)
- self._path.write_text(
- json.dumps({"containers": containers}, indent=2, ensure_ascii=False, default=str),
- encoding="utf-8",
- )
- def create(self, info: ContainerInfo) -> None:
- with self._lock:
- containers = self._load()
- containers.append(info.model_dump(mode="json"))
- self._save(containers)
- def get(self, container_id: str) -> ContainerInfo | None:
- for item in self._load():
- if item["container_id"] == container_id and item["status"] == ContainerStatus.RUNNING:
- return ContainerInfo(**item)
- return None
- def get_all_active(self) -> list[ContainerInfo]:
- return [
- ContainerInfo(**item)
- for item in self._load()
- if item["status"] == ContainerStatus.RUNNING
- ]
- def update_last_accessed(self, container_id: str) -> None:
- with self._lock:
- containers = self._load()
- for item in containers:
- if item["container_id"] == container_id and item["status"] == ContainerStatus.RUNNING:
- item["last_accessed"] = datetime.now(timezone.utc).isoformat()
- break
- self._save(containers)
- def mark_destroyed(self, container_id: str) -> None:
- with self._lock:
- containers = self._load()
- for item in containers:
- if item["container_id"] == container_id and item["status"] == ContainerStatus.RUNNING:
- item["status"] = ContainerStatus.DESTROYED
- item["destroyed_at"] = datetime.now(timezone.utc).isoformat()
- break
- self._save(containers)
- def exists(self, container_id: str) -> bool:
- return any(
- item["container_id"] == container_id and item["status"] == ContainerStatus.RUNNING
- for item in self._load()
- )
- def get_expired(self, ttl_seconds: int) -> list[ContainerInfo]:
- now = datetime.now(timezone.utc)
- expired = []
- for item in self._load():
- if item["status"] != ContainerStatus.RUNNING:
- continue
- last = item.get("last_accessed") or item.get("created_at")
- if last:
- last_dt = datetime.fromisoformat(last) if isinstance(last, str) else last
- if last_dt.tzinfo is None:
- last_dt = last_dt.replace(tzinfo=timezone.utc)
- if (now - last_dt).total_seconds() > ttl_seconds:
- expired.append(ContainerInfo(**item))
- return expired
- def count_active(self) -> int:
- return sum(1 for item in self._load() if item["status"] == ContainerStatus.RUNNING)
- class DockerRunner:
- """Docker 容器完整生命周期管理
- 以 sandbox_manager.py 为蓝本,提供:
- - Docker 客户端懒加载
- - 容器创建(端口映射、资源限制、GPU)
- - 容器内命令执行(前台带超时 + 后台)
- - HTTP 调用容器内工具服务
- - 容器启停/销毁/重建
- - 健康检查
- - 线程安全容器缓存
- - 启动时从 JSON 恢复
- """
- def __init__(self, lazy_init: bool = True) -> None:
- self._docker_client: docker.DockerClient | None = None
- self._container_cache: dict[str, docker.models.containers.Container] = {}
- self._lock = threading.Lock()
- self._on_destroy_callbacks: list[Callable[[str], None]] = []
- self._store = ContainerStore()
- if not lazy_init:
- self._init_docker()
- # ---- Docker 客户端 ----
- @property
- def client(self) -> docker.DockerClient:
- if self._docker_client is None:
- self._init_docker()
- return self._docker_client
- def _init_docker(self) -> None:
- """连接 Docker(本地或远程)"""
- if settings.docker_host:
- # 远程 Docker:通过 SSH 连接
- ssh_key = settings.docker_ssh_key
- remote_host = settings.docker_host
- # Docker SDK 原生支持 SSH 连接
- # 格式:ssh://user@host
- docker_url = f"ssh://root@{remote_host}"
- logger.info(f"Connecting to remote Docker via SSH: {docker_url}")
- # 设置 SSH 密钥环境变量(Docker SDK 会使用)
- import os
- os.environ["DOCKER_SSH_KEY"] = ssh_key
- try:
- self._docker_client = docker.DockerClient(
- base_url=docker_url,
- timeout=30,
- use_ssh_client=True, # 使用系统的 SSH 客户端
- )
- # 测试连接
- self._docker_client.ping()
- logger.info(f"Successfully connected to remote Docker at {remote_host}")
- except Exception as e:
- logger.error(f"Failed to connect to remote Docker: {e}")
- raise RuntimeError(f"Cannot connect to remote Docker at {remote_host}: {e}")
- else:
- # 本地 Docker
- self._docker_client = docker.from_env()
- self._ensure_base_image()
- self._restore_container_cache()
- def _ensure_base_image(self) -> None:
- """检查基础镜像是否存在,不存在则从 Dockerfile 构建"""
- image_name = settings.docker_base_image
- try:
- self.client.images.get(image_name)
- logger.info(f"Base image '{image_name}' found locally.")
- except docker.errors.ImageNotFound:
- logger.info(f"Base image '{image_name}' not found. Building...")
- dockerfile_dir = settings.tools_dir / "docker"
- dockerfile_path = dockerfile_dir / "Dockerfile.sandbox"
- if not dockerfile_path.exists():
- logger.warning(f"Dockerfile '{dockerfile_path}' not found, skipping build.")
- return
- try:
- image, build_logs = self.client.images.build(
- path=str(dockerfile_dir),
- dockerfile="Dockerfile.sandbox",
- tag=image_name,
- rm=True,
- )
- for chunk in build_logs:
- if "stream" in chunk:
- logger.debug(chunk["stream"].strip())
- logger.info(f"Successfully built '{image_name}'.")
- except Exception as e:
- logger.error(f"Failed to build base image: {e}")
- def _restore_container_cache(self) -> None:
- """启动时从 JSON 恢复活跃容器到内存缓存"""
- for info in self._store.get_all_active():
- try:
- container = self.client.containers.get(info.container_id)
- if container.status == "running":
- with self._lock:
- self._container_cache[info.container_id] = container
- logger.info(f"Restored container cache: {info.container_id[:12]}")
- else:
- self._store.mark_stopped(info.container_id)
- logger.info(f"Container not running, marked stopped: {info.container_id[:12]}")
- except docker.errors.NotFound:
- self._store.mark_destroyed(info.container_id)
- logger.warning(f"Container not found, marked destroyed: {info.container_id[:12]}")
- # ---- 回调 ----
- def add_on_destroy_callback(self, callback: Callable[[str], None]) -> None:
- self._on_destroy_callbacks.append(callback)
- def _trigger_destroy_callbacks(self, container_id: str) -> None:
- for cb in self._on_destroy_callbacks:
- try:
- cb(container_id)
- except Exception as e:
- logger.error(f"Destroy callback failed: {e}")
- # ---- 端口 ----
- @staticmethod
- def _get_free_port() -> int:
- """获取一个空闲的宿主机端口"""
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("", 0))
- return s.getsockname()[1]
- # ---- 容器生命周期 ----
- def create_container(
- self,
- tool_id: str = "",
- image: str | None = None,
- mem_limit: str | None = None,
- nano_cpus: int | None = None,
- ports: list[int] | None = None,
- volumes: dict[str, str] | None = None,
- use_gpu: bool = False,
- gpu_count: int = -1,
- ) -> dict[str, Any]:
- """创建新容器
- Args:
- tool_id: 关联的工具 ID
- image: 镜像名称,默认使用 settings.docker_base_image
- mem_limit: 内存限制,如 "1g"
- nano_cpus: CPU 限制,1_000_000_000 = 1 CPU
- ports: 需要映射的容器端口列表,如 [8080, 3000]
- volumes: 目录挂载,{宿主机路径: 容器路径},如 {"/home/user/project": "/app"}
- use_gpu: 是否启用 GPU
- gpu_count: GPU 数量,-1 表示全部
- """
- image = image or settings.docker_base_image
- mem_limit = mem_limit or settings.docker_mem_limit
- nano_cpus = nano_cpus or settings.docker_nano_cpus
- # 容器数量上限检查
- active_count = self._store.count_active()
- if active_count >= settings.hot_tool_max_containers:
- return {"error": f"Container limit reached ({settings.hot_tool_max_containers}). Destroy unused containers first."}
- try:
- # 端口映射
- port_bindings: dict[str, int] = {}
- port_mapping: dict[int, int] = {}
- if ports:
- for container_port in ports:
- host_port = self._get_free_port()
- port_bindings[f"{container_port}/tcp"] = host_port
- port_mapping[container_port] = host_port
- # 目录挂载
- docker_volumes = {}
- if volumes:
- for host_path, container_path in volumes.items():
- docker_volumes[host_path] = {"bind": container_path, "mode": "ro"}
- # GPU 配置
- device_requests = None
- if use_gpu:
- device_requests = [
- docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
- ]
- container = self.client.containers.run(
- image,
- command="tail -f /dev/null",
- detach=True,
- ports=port_bindings or None,
- volumes=docker_volumes or None,
- working_dir="/app",
- mem_limit=mem_limit,
- nano_cpus=nano_cpus,
- device_requests=device_requests,
- security_opt=["no-new-privileges"],
- )
- now = datetime.now(timezone.utc)
- info = ContainerInfo(
- container_id=container.id,
- tool_id=tool_id,
- image=image,
- port_mapping=port_mapping,
- volumes=volumes or {},
- mem_limit=mem_limit,
- nano_cpus=nano_cpus,
- use_gpu=use_gpu,
- gpu_count=gpu_count,
- created_at=now,
- last_accessed=now,
- )
- self._store.create(info)
- with self._lock:
- self._container_cache[container.id] = container
- logger.info(f"Created container {container.id[:12]} for tool '{tool_id}', ports={port_mapping}")
- return {
- "container_id": container.id,
- "tool_id": tool_id,
- "port_mapping": port_mapping,
- "message": "Container created.",
- }
- except Exception as e:
- logger.error(f"Failed to create container: {e}")
- return {"error": str(e)}
- def _get_container(self, container_id: str) -> docker.models.containers.Container | None:
- """从缓存或 Docker 获取容器对象"""
- with self._lock:
- container = self._container_cache.get(container_id)
- if container:
- return container
- try:
- container = self.client.containers.get(container_id)
- with self._lock:
- self._container_cache[container_id] = container
- return container
- except docker.errors.NotFound:
- self._store.mark_destroyed(container_id)
- return None
- def run_command(
- self,
- container_id: str,
- command: str,
- background: bool = False,
- timeout: int = 120,
- ) -> dict[str, Any]:
- """在容器内执行命令
- Args:
- container_id: 容器 ID
- command: Shell 命令
- background: 是否后台执行
- timeout: 前台命令超时秒数
- """
- if not self._store.exists(container_id):
- return {"error": "Container not found"}
- self._store.update_last_accessed(container_id)
- container = self._get_container(container_id)
- if not container:
- return {"error": "Container object not found"}
- logger.info(f"Running in {container_id[:12]}: {command} (bg={background})")
- try:
- if background:
- log_file = f"background_{uuid.uuid4().hex[:8]}.log"
- encoded_cmd = base64.b64encode(command.encode()).decode()
- safe_cmd = f"echo {encoded_cmd} | base64 -d | nohup sh > /app/{log_file} 2>&1 &"
- container.exec_run(["sh", "-c", safe_cmd], detach=True)
- return {
- "status": "success",
- "message": "Command started in background",
- "log_file": f"/app/{log_file}",
- }
- else:
- result_box: dict[str, Any] = {}
- def _exec():
- try:
- result_box["exec"] = container.exec_run(
- ["sh", "-c", command], demux=True,
- )
- except Exception as e:
- result_box["error"] = str(e)
- thread = threading.Thread(target=_exec, daemon=True)
- thread.start()
- thread.join(timeout=timeout)
- if thread.is_alive():
- return {"error": f"Command timeout after {timeout}s"}
- if "error" in result_box:
- return {"error": result_box["error"]}
- exec_result = result_box["exec"]
- stdout, stderr = exec_result.output
- return {
- "exit_code": exec_result.exit_code,
- "stdout": stdout.decode("utf-8", errors="replace") if stdout else "",
- "stderr": stderr.decode("utf-8", errors="replace") if stderr else "",
- }
- except Exception as e:
- return {"error": str(e)}
- def destroy_container(self, container_id: str) -> dict[str, Any]:
- """销毁容器并释放资源"""
- if not self._store.exists(container_id):
- return {"error": "Container not found"}
- with self._lock:
- container = self._container_cache.pop(container_id, None)
- if not container:
- try:
- container = self.client.containers.get(container_id)
- except docker.errors.NotFound:
- self._store.mark_destroyed(container_id)
- return {"error": "Container not found in Docker"}
- self._store.mark_destroyed(container_id)
- self._trigger_destroy_callbacks(container_id)
- try:
- container.remove(force=True)
- logger.info(f"Destroyed container {container_id[:12]}")
- return {"status": "success", "message": f"Container {container_id[:12]} destroyed"}
- except Exception as e:
- return {"error": str(e)}
- def rebuild_with_ports(
- self,
- container_id: str,
- ports: list[int],
- mem_limit: str | None = None,
- nano_cpus: int | None = None,
- use_gpu: bool = False,
- gpu_count: int = -1,
- ) -> dict[str, Any]:
- """重建容器并应用新端口映射,保留文件系统状态
- 通过 commit → 重建 → 清理临时镜像实现。
- """
- container = self._get_container(container_id)
- if not container:
- return {"error": "Container not found"}
- info = self._store.get(container_id)
- mem_limit = mem_limit or (info.mem_limit if info else settings.docker_mem_limit)
- nano_cpus = nano_cpus or (info.nano_cpus if info else settings.docker_nano_cpus)
- tool_id = info.tool_id if info else ""
- old_volumes = info.volumes if info else {}
- try:
- # 1. commit 当前容器为临时镜像
- temp_tag = f"sandbox-temp-{uuid.uuid4().hex[:8]}"
- logger.info(f"Committing {container_id[:12]} → {temp_tag}")
- container.commit(repository=temp_tag)
- # 2. 新端口映射
- port_bindings: dict[str, int] = {}
- port_mapping: dict[int, int] = {}
- for p in ports:
- hp = self._get_free_port()
- port_bindings[f"{p}/tcp"] = hp
- port_mapping[p] = hp
- # 恢复目录挂载
- docker_volumes = {}
- if old_volumes:
- for host_path, container_path in old_volumes.items():
- docker_volumes[host_path] = {"bind": container_path, "mode": "ro"}
- device_requests = None
- if use_gpu:
- device_requests = [
- docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
- ]
- # 3. 创建新容器
- new_container = self.client.containers.run(
- temp_tag,
- command="tail -f /dev/null",
- detach=True,
- ports=port_bindings or None,
- volumes=docker_volumes or None,
- working_dir="/app",
- mem_limit=mem_limit,
- nano_cpus=nano_cpus,
- device_requests=device_requests,
- security_opt=["no-new-privileges"],
- )
- # 4. 清理旧容器
- self._store.mark_destroyed(container_id)
- with self._lock:
- self._container_cache.pop(container_id, None)
- container.remove(force=True)
- # 5. 清理临时镜像
- try:
- self.client.images.remove(temp_tag, force=True)
- except Exception as e:
- logger.warning(f"Failed to remove temp image: {e}")
- # 6. 保存新容器
- now = datetime.now(timezone.utc)
- new_info = ContainerInfo(
- container_id=new_container.id,
- tool_id=tool_id,
- image=info.image if info else settings.docker_base_image,
- port_mapping=port_mapping,
- volumes=old_volumes,
- mem_limit=mem_limit,
- nano_cpus=nano_cpus,
- use_gpu=use_gpu,
- gpu_count=gpu_count,
- created_at=now,
- last_accessed=now,
- )
- self._store.create(new_info)
- with self._lock:
- self._container_cache[new_container.id] = new_container
- logger.info(f"Rebuilt {container_id[:12]} → {new_container.id[:12]}, ports={port_mapping}")
- return {
- "old_container_id": container_id,
- "new_container_id": new_container.id,
- "port_mapping": port_mapping,
- "message": "Container rebuilt with new port mappings. All files preserved.",
- }
- except Exception as e:
- logger.error(f"Failed to rebuild container: {e}")
- return {"error": str(e)}
- def start_container(self, container_id: str) -> bool:
- """启动已停止的容器"""
- container = self._get_container(container_id)
- if not container:
- return False
- try:
- container.start()
- logger.info(f"Started container {container_id[:12]}")
- return True
- except Exception as e:
- logger.error(f"Failed to start container: {e}")
- return False
- def stop_container(self, container_id: str) -> bool:
- """停止运行中的容器"""
- container = self._get_container(container_id)
- if not container:
- return False
- try:
- container.stop(timeout=10)
- logger.info(f"Stopped container {container_id[:12]}")
- return True
- except Exception as e:
- logger.error(f"Failed to stop container: {e}")
- return False
- # ---- 工具调用 (HTTP) ----
- async def run(self, tool: ToolMeta, params: dict[str, Any], stream: bool = False) -> dict[str, Any]:
- """通过 HTTP 调用容器内工具服务"""
- # 从注册表中找到该工具对应的容器端口
- for info in self._store.get_all_active():
- if info.tool_id == tool.tool_id and info.port_mapping:
- # 取第一个映射端口作为服务端口
- host_port = next(iter(info.port_mapping.values()))
- self._store.update_last_accessed(info.container_id)
- url = f"http://localhost:{host_port}/run"
- payload = {"params": params, "stream": stream}
- try:
- async with httpx.AsyncClient(timeout=60) as client:
- resp = await client.post(url, json=payload)
- return resp.json()
- except Exception as e:
- return {"status": "error", "error": str(e)}
- return {"status": "error", "error": f"No running container for tool '{tool.tool_id}'"}
- async def health_check(self, container_id: str) -> bool:
- """HTTP 健康检查"""
- info = self._store.get(container_id)
- if not info or not info.port_mapping:
- return False
- host_port = next(iter(info.port_mapping.values()))
- try:
- async with httpx.AsyncClient(timeout=5) as client:
- resp = await client.get(f"http://localhost:{host_port}/health")
- return resp.status_code == 200
- except Exception:
- return False
- # ---- 自动清理 ----
- def cleanup_expired(self) -> list[str]:
- """清理超过 TTL 的容器,返回被清理的 container_id 列表"""
- expired = self._store.get_expired(settings.docker_ttl_seconds)
- cleaned = []
- for info in expired:
- logger.info(f"Auto-cleaning expired container {info.container_id[:12]}")
- result = self.destroy_container(info.container_id)
- if "error" not in result:
- cleaned.append(info.container_id)
- return cleaned
- # ---- 查询 ----
- def list_active(self) -> list[ContainerInfo]:
- return self._store.get_all_active()
- def get_container_info(self, container_id: str) -> ContainerInfo | None:
- return self._store.get(container_id)
|