registry.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
"""
Agent Registry

Manages registration info and connections for online agents.
"""
  5. import asyncio
  6. from datetime import datetime, timedelta
  7. from typing import Dict, Optional, List
  8. from dataclasses import dataclass, field
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. @dataclass
  12. class AgentConnection:
  13. """Agent 连接信息"""
  14. agent_uri: str
  15. connection_type: str # "websocket" | "http"
  16. websocket: Optional[object] = None # WebSocket 连接对象
  17. http_endpoint: Optional[str] = None # HTTP 端点
  18. capabilities: List[str] = field(default_factory=list)
  19. registered_at: datetime = field(default_factory=datetime.now)
  20. last_heartbeat: datetime = field(default_factory=datetime.now)
  21. metadata: Dict = field(default_factory=dict)
  22. class AgentRegistry:
  23. """Agent 注册表"""
  24. def __init__(self, heartbeat_timeout: int = 60):
  25. """
  26. Args:
  27. heartbeat_timeout: 心跳超时时间(秒)
  28. """
  29. self.agents: Dict[str, AgentConnection] = {}
  30. self.heartbeat_timeout = heartbeat_timeout
  31. self._cleanup_task = None
  32. async def start(self):
  33. """启动注册表(定期清理过期连接)"""
  34. self._cleanup_task = asyncio.create_task(self._cleanup_loop())
  35. logger.info("AgentRegistry started")
  36. async def stop(self):
  37. """停止注册表"""
  38. if self._cleanup_task:
  39. self._cleanup_task.cancel()
  40. try:
  41. await self._cleanup_task
  42. except asyncio.CancelledError:
  43. pass
  44. logger.info("AgentRegistry stopped")
  45. async def register(
  46. self,
  47. agent_uri: str,
  48. connection_type: str,
  49. websocket: Optional[object] = None,
  50. http_endpoint: Optional[str] = None,
  51. capabilities: Optional[List[str]] = None,
  52. metadata: Optional[Dict] = None
  53. ) -> AgentConnection:
  54. """
  55. 注册 Agent
  56. Args:
  57. agent_uri: Agent URI (agent://domain/id)
  58. connection_type: 连接类型 (websocket | http)
  59. websocket: WebSocket 连接对象
  60. http_endpoint: HTTP 端点
  61. capabilities: Agent 能力列表
  62. metadata: 额外元数据
  63. """
  64. connection = AgentConnection(
  65. agent_uri=agent_uri,
  66. connection_type=connection_type,
  67. websocket=websocket,
  68. http_endpoint=http_endpoint,
  69. capabilities=capabilities or [],
  70. metadata=metadata or {}
  71. )
  72. self.agents[agent_uri] = connection
  73. logger.info(f"Agent registered: {agent_uri} ({connection_type})")
  74. return connection
  75. async def unregister(self, agent_uri: str):
  76. """注销 Agent"""
  77. if agent_uri in self.agents:
  78. del self.agents[agent_uri]
  79. logger.info(f"Agent unregistered: {agent_uri}")
  80. async def heartbeat(self, agent_uri: str):
  81. """更新心跳时间"""
  82. if agent_uri in self.agents:
  83. self.agents[agent_uri].last_heartbeat = datetime.now()
  84. def lookup(self, agent_uri: str) -> Optional[AgentConnection]:
  85. """查找 Agent 连接信息"""
  86. return self.agents.get(agent_uri)
  87. def is_online(self, agent_uri: str) -> bool:
  88. """检查 Agent 是否在线"""
  89. connection = self.lookup(agent_uri)
  90. if not connection:
  91. return False
  92. # 检查心跳是否超时
  93. timeout = timedelta(seconds=self.heartbeat_timeout)
  94. return datetime.now() - connection.last_heartbeat < timeout
  95. def list_agents(
  96. self,
  97. connection_type: Optional[str] = None,
  98. online_only: bool = True
  99. ) -> List[AgentConnection]:
  100. """
  101. 列出 Agent
  102. Args:
  103. connection_type: 过滤连接类型
  104. online_only: 只返回在线的 Agent
  105. """
  106. agents = list(self.agents.values())
  107. if connection_type:
  108. agents = [a for a in agents if a.connection_type == connection_type]
  109. if online_only:
  110. timeout = timedelta(seconds=self.heartbeat_timeout)
  111. now = datetime.now()
  112. agents = [a for a in agents if now - a.last_heartbeat < timeout]
  113. return agents
  114. async def _cleanup_loop(self):
  115. """定期清理过期连接"""
  116. while True:
  117. try:
  118. await asyncio.sleep(30) # 每 30 秒检查一次
  119. await self._cleanup_expired()
  120. except asyncio.CancelledError:
  121. break
  122. except Exception as e:
  123. logger.error(f"Cleanup error: {e}")
  124. async def _cleanup_expired(self):
  125. """清理过期连接"""
  126. timeout = timedelta(seconds=self.heartbeat_timeout)
  127. now = datetime.now()
  128. expired = [
  129. uri for uri, conn in self.agents.items()
  130. if now - conn.last_heartbeat > timeout
  131. ]
  132. for uri in expired:
  133. await self.unregister(uri)
  134. logger.info(f"Agent expired: {uri}")