| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- """
- Agent Registry
- 管理在线 Agent 的注册信息和连接
- """
- import asyncio
- from datetime import datetime, timedelta
- from typing import Dict, Optional, List
- from dataclasses import dataclass, field
- import logging
# Logger named after this module; used by the registry code below.
logger = logging.getLogger(name=__name__)
@dataclass
class AgentConnection:
    """Everything the registry knows about one registered agent.

    Only ``agent_uri`` and ``connection_type`` are required. Every other
    field falls back to one of:

    - ``None``;
    - a fresh empty container per instance;
    - the current local time.
    """

    # Agent URI, e.g. "agent://domain/id".
    agent_uri: str
    # Transport used to reach the agent: "websocket" or "http".
    connection_type: str
    # WebSocket connection object, if any (deliberately left untyped here).
    websocket: Optional[object] = field(default=None)
    # HTTP endpoint, if any.
    http_endpoint: Optional[str] = field(default=None)
    # Capabilities advertised by the agent; a new list for every instance.
    capabilities: List[str] = field(default_factory=list)
    # Naive local timestamps from datetime.now(), each taken independently
    # when the instance is constructed.
    registered_at: datetime = field(default_factory=datetime.now)
    last_heartbeat: datetime = field(default_factory=datetime.now)
    # Extra free-form metadata; a new dict for every instance.
    metadata: Dict = field(default_factory=dict)
class AgentRegistry:
    """In-memory registry of agents, keyed by agent URI.

    An agent counts as online while its last heartbeat is younger than
    ``heartbeat_timeout`` seconds. Once :meth:`start` has been awaited, a
    background task periodically unregisters agents that are no longer online.

    NOTE(review): timestamps come from naive ``datetime.now()`` (local wall
    clock). A DST change or clock adjustment can therefore make agents look
    expired (or alive) for the wrong amount of time. A monotonic clock would
    be more robust, but it would change ``AgentConnection``'s public
    timestamp fields.
    """

    def __init__(self, heartbeat_timeout: int = 60, cleanup_interval: float = 30):
        """
        Args:
            heartbeat_timeout: Seconds without a heartbeat after which an
                agent is considered offline.
            cleanup_interval: Seconds between sweeps of the background
                cleanup task. Defaults to 30, the previously hard-coded value.

        Raises:
            ValueError: If ``cleanup_interval`` is not positive. A zero or
                negative interval would make the cleanup loop spin.
        """
        if cleanup_interval <= 0:
            raise ValueError("cleanup_interval must be positive")
        self.agents: Dict[str, AgentConnection] = {}
        self.heartbeat_timeout = heartbeat_timeout
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background task that sweeps expired agents.

        If the cleanup task is already running, this is a no-op. A second
        call therefore no longer spawns (and orphans) a duplicate task.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("AgentRegistry started")

    async def stop(self):
        """Cancel the background cleanup task (if any) and wait for it to end."""
        # Clear the handle first so start() can create a fresh task afterwards.
        task, self._cleanup_task = self._cleanup_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("AgentRegistry stopped")

    async def register(
        self,
        agent_uri: str,
        connection_type: str,
        websocket: Optional[object] = None,
        http_endpoint: Optional[str] = None,
        capabilities: Optional[List[str]] = None,
        metadata: Optional[Dict] = None
    ) -> AgentConnection:
        """Register an agent and return its new connection record.

        Args:
            agent_uri: Agent URI (agent://domain/id).
            connection_type: Connection type ("websocket" or "http").
            websocket: WebSocket connection object.
            http_endpoint: HTTP endpoint.
            capabilities: List of the agent's capabilities. Defaults to an
                empty list.
            metadata: Extra metadata. Defaults to an empty dict.

        Returns:
            The newly created ``AgentConnection``.
        """
        connection = AgentConnection(
            agent_uri=agent_uri,
            connection_type=connection_type,
            websocket=websocket,
            http_endpoint=http_endpoint,
            capabilities=capabilities or [],
            metadata=metadata or {}
        )
        # Re-registering an existing URI replaces the previous entry. The old
        # connection object is not closed here.
        self.agents[agent_uri] = connection
        logger.info(f"Agent registered: {agent_uri} ({connection_type})")
        return connection

    async def unregister(self, agent_uri: str):
        """Remove an agent from the registry; unknown URIs are ignored."""
        if self.agents.pop(agent_uri, None) is not None:
            logger.info(f"Agent unregistered: {agent_uri}")

    async def heartbeat(self, agent_uri: str):
        """Set the agent's ``last_heartbeat`` to now; unknown URIs are ignored."""
        connection = self.agents.get(agent_uri)
        if connection is not None:
            connection.last_heartbeat = datetime.now()

    def lookup(self, agent_uri: str) -> Optional[AgentConnection]:
        """Return the agent's connection record, or None if it is not registered."""
        return self.agents.get(agent_uri)

    def _is_alive(self, connection: AgentConnection, now: datetime) -> bool:
        """Return True if the last heartbeat is under ``heartbeat_timeout`` seconds before ``now``.

        This is the single definition of the online/expired boundary. It is
        shared by is_online, list_agents and _cleanup_expired so they agree.
        """
        timeout = timedelta(seconds=self.heartbeat_timeout)
        return now - connection.last_heartbeat < timeout

    def is_online(self, agent_uri: str) -> bool:
        """Return True if the agent is registered and its heartbeat has not timed out."""
        connection = self.lookup(agent_uri)
        if connection is None:
            return False
        return self._is_alive(connection, datetime.now())

    def list_agents(
        self,
        connection_type: Optional[str] = None,
        online_only: bool = True
    ) -> List[AgentConnection]:
        """List registered agents in registration (dict insertion) order.

        Args:
            connection_type: If given (and non-empty), keep only agents with
                this connection type.
            online_only: If True, keep only agents whose heartbeat has not
                timed out.
        """
        agents = list(self.agents.values())
        if connection_type:
            agents = [a for a in agents if a.connection_type == connection_type]
        if online_only:
            # One timestamp for the whole pass so every agent uses the same "now".
            now = datetime.now()
            agents = [a for a in agents if self._is_alive(a, now)]
        return agents

    async def _cleanup_loop(self):
        """Sweep expired agents every ``cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
            except asyncio.CancelledError:
                # Let cancellation propagate so the task really ends up
                # cancelled; stop() awaits the task and absorbs it. The
                # explicit clause matters on Python < 3.8, where
                # CancelledError subclasses Exception and would otherwise be
                # swallowed by the handler below.
                raise
            except Exception as e:
                # Keep the sweeper alive on unexpected errors, but record the
                # full traceback instead of just the message.
                logger.exception(f"Cleanup error: {e}")

    async def _cleanup_expired(self):
        """Unregister every agent that is no longer online.

        This uses the same boundary as is_online. An agent whose heartbeat is
        exactly ``heartbeat_timeout`` old is offline, and is now removed too;
        previously it was reported offline but kept.
        """
        now = datetime.now()
        # Snapshot the URIs first, because unregister() mutates self.agents.
        expired = [
            uri for uri, conn in self.agents.items()
            if not self._is_alive(conn, now)
        ]
        for uri in expired:
            await self.unregister(uri)
            logger.info(f"Agent expired: {uri}")
|