"""
Agent Registry

Tracks the registration info and connections of online agents.
"""

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class AgentConnection:
    """Connection info for a single registered agent."""

    agent_uri: str
    connection_type: str  # "websocket" | "http"
    websocket: Optional[object] = None  # WebSocket connection object
    http_endpoint: Optional[str] = None  # HTTP endpoint
    capabilities: List[str] = field(default_factory=list)
    # NOTE(review): both timestamps are naive local wall-clock times
    # (datetime.now()), so elapsed-time liveness checks can be skewed by DST
    # transitions or system clock adjustments; consider time.monotonic().
    registered_at: datetime = field(default_factory=datetime.now)
    last_heartbeat: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)


class AgentRegistry:
    """In-memory registry of agent connections with heartbeat-based liveness.

    An agent counts as online while strictly less than ``heartbeat_timeout``
    seconds have elapsed since its last heartbeat. A background task started
    by :meth:`start` periodically unregisters agents that are not online.
    """

    def __init__(self, heartbeat_timeout: int = 60, cleanup_interval: float = 30.0):
        """
        Args:
            heartbeat_timeout: Heartbeat timeout in seconds.
            cleanup_interval: Seconds between sweeps of the background cleanup
                task (defaults to 30, the previously hard-coded value).

        Raises:
            ValueError: If ``cleanup_interval`` is not positive (a zero
                interval would turn the cleanup loop into a busy loop).
        """
        if cleanup_interval <= 0:
            raise ValueError("cleanup_interval must be positive")
        self.agents: Dict[str, AgentConnection] = {}
        self.heartbeat_timeout = heartbeat_timeout
        self.cleanup_interval = cleanup_interval
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the registry's background cleanup task.

        Idempotent: if the cleanup task is already running this is a no-op,
        so calling it twice does not leave an orphaned task behind.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("AgentRegistry started")

    async def stop(self):
        """Cancel the background cleanup task (if any) and wait for it to end."""
        # Detach first so a later start() always creates a fresh task.
        task, self._cleanup_task = self._cleanup_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("AgentRegistry stopped")

    async def register(
        self,
        agent_uri: str,
        connection_type: str,
        websocket: Optional[object] = None,
        http_endpoint: Optional[str] = None,
        capabilities: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> AgentConnection:
        """Register an agent, replacing any existing entry for the same URI.

        Args:
            agent_uri: Agent URI (agent://domain/id).
            connection_type: Connection type ("websocket" | "http").
            websocket: WebSocket connection object.
            http_endpoint: HTTP endpoint.
            capabilities: List of agent capabilities.
            metadata: Extra metadata.

        Returns:
            The newly stored :class:`AgentConnection`.
        """
        connection = AgentConnection(
            agent_uri=agent_uri,
            connection_type=connection_type,
            websocket=websocket,
            http_endpoint=http_endpoint,
            capabilities=capabilities or [],
            metadata=metadata or {},
        )
        self.agents[agent_uri] = connection
        logger.info("Agent registered: %s (%s)", agent_uri, connection_type)
        return connection

    async def unregister(self, agent_uri: str) -> bool:
        """Remove an agent from the registry.

        Returns:
            True if the agent was registered and has been removed, else False.
        """
        if self.agents.pop(agent_uri, None) is None:
            return False
        logger.info("Agent unregistered: %s", agent_uri)
        return True

    async def heartbeat(self, agent_uri: str) -> bool:
        """Refresh an agent's last-heartbeat time to now.

        Returns:
            True if the agent is registered. False if it is unknown (for
            example, it was already removed by the cleanup sweep), so callers
            can detect that case instead of the heartbeat being silently lost.
        """
        connection = self.agents.get(agent_uri)
        if connection is None:
            return False
        connection.last_heartbeat = datetime.now()
        return True

    def lookup(self, agent_uri: str) -> Optional[AgentConnection]:
        """Return the connection info for ``agent_uri``, or None if unknown."""
        return self.agents.get(agent_uri)

    def _is_alive(self, connection: AgentConnection, now: datetime) -> bool:
        """Return True if ``connection`` heartbeated < ``heartbeat_timeout`` s before ``now``.

        Single source of truth for liveness, so ``is_online``, ``list_agents``
        and the cleanup sweep all agree on the exact timeout boundary.
        """
        return now - connection.last_heartbeat < timedelta(seconds=self.heartbeat_timeout)

    def is_online(self, agent_uri: str) -> bool:
        """Return True if the agent is registered and its heartbeat has not timed out."""
        connection = self.lookup(agent_uri)
        return connection is not None and self._is_alive(connection, datetime.now())

    def list_agents(
        self,
        connection_type: Optional[str] = None,
        online_only: bool = True,
    ) -> List[AgentConnection]:
        """List registered agents.

        Args:
            connection_type: Only return agents with this connection type
                (a falsy value disables the filter).
            online_only: Only return agents whose heartbeat has not timed out.
        """
        now = datetime.now()
        return [
            conn
            for conn in self.agents.values()
            if (not connection_type or conn.connection_type == connection_type)
            and (not online_only or self._is_alive(conn, now))
        ]

    async def _cleanup_loop(self):
        """Sweep expired agents every ``cleanup_interval`` seconds until cancelled."""
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                await self._cleanup_expired()
            except asyncio.CancelledError:
                # Let cancellation propagate so the task ends as cancelled
                # (needed explicitly where CancelledError subclasses Exception).
                raise
            except Exception as e:
                # Keep the loop alive, but record the full traceback.
                logger.exception("Cleanup error: %s", e)

    async def _cleanup_expired(self) -> int:
        """Unregister every agent whose heartbeat has timed out.

        Returns:
            The number of agents removed.
        """
        now = datetime.now()
        expired = [uri for uri, conn in self.agents.items() if not self._is_alive(conn, now)]
        removed = 0
        for uri in expired:
            # Re-check: across the awaits below the agent may have sent a
            # heartbeat or re-registered since the snapshot was taken.
            conn = self.agents.get(uri)
            if conn is None or self._is_alive(conn, now):
                continue
            logger.info("Agent expired: %s", uri)
            if await self.unregister(uri):
                removed += 1
        return removed