"""
Gateway Client

Client used on the Agent side to connect to the Gateway.
"""

import asyncio
import contextlib
import inspect
import json
import logging
from typing import Optional, Dict, Any, Callable

import websockets

logger = logging.getLogger(__name__)


class GatewayClient:
    """WebSocket client that registers an Agent with the Gateway.

    After a successful :meth:`connect`, two background tasks run: one sends a
    heartbeat frame every 30 seconds, the other receives frames from the
    Gateway and dispatches ``"message"`` frames to ``on_message``.
    """

    def __init__(
        self,
        gateway_url: str,
        agent_uri: str,
        capabilities: Optional[list] = None,
        on_message: Optional[Callable] = None
    ):
        """
        Args:
            gateway_url: Gateway WebSocket URL (wss://gateway.com/gateway/connect)
            agent_uri: URI of this Agent.
            capabilities: List of capabilities sent in the registration frame.
            on_message: Callback invoked with each decoded ``"message"`` frame
                (a dict). May be a coroutine function or a plain function.
        """
        self.gateway_url = gateway_url
        self.agent_uri = agent_uri
        self.capabilities = capabilities or []
        self.on_message = on_message
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._receive_task: Optional[asyncio.Task] = None

    async def connect(self):
        """Open the WebSocket, register this Agent and start background tasks.

        Sends a ``register`` frame and waits for the first reply. If the reply
        is an object whose ``type`` is ``"registered"``, the client is marked
        connected and the heartbeat and receive tasks are started.

        Raises:
            Exception: If connecting or registration fails. Any socket opened
                by this call is closed before the error propagates.
        """
        try:
            self.ws = await websockets.connect(self.gateway_url)

            # Send the registration frame.
            await self.ws.send(json.dumps({
                "type": "register",
                "agent_uri": self.agent_uri,
                "capabilities": self.capabilities
            }))

            # Wait for the registration acknowledgement.
            response = await self.ws.recv()
            msg = json.loads(response)

            if isinstance(msg, dict) and msg.get("type") == "registered":
                self.connected = True
                logger.info("Connected to Gateway: %s", self.agent_uri)

                # Start the heartbeat and receive tasks.
                self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
                self._receive_task = asyncio.create_task(self._receive_loop())
            else:
                raise Exception(f"Registration failed: {msg}")

        except Exception as e:
            logger.error("Failed to connect to Gateway: %s", e)
            # Do not leak the socket when registration did not complete.
            # Close errors are suppressed so the original error propagates.
            if self.ws is not None:
                with contextlib.suppress(Exception):
                    await self.ws.close()
            raise

    async def disconnect(self):
        """Stop the background tasks and close the connection.

        Safe to call from inside the ``on_message`` callback. The task running
        the callback is not cancelled. It exits on its own because
        ``connected`` is now False.
        """
        self.connected = False

        current = asyncio.current_task()
        tasks = [
            task for task in (self._heartbeat_task, self._receive_task)
            if task is not None and task is not current
        ]
        for task in tasks:
            task.cancel()
        # Wait for the cancelled tasks so they do not outlive the socket.
        await asyncio.gather(*tasks, return_exceptions=True)
        self._heartbeat_task = None
        self._receive_task = None

        if self.ws is not None:
            await self.ws.close()
        logger.info("Disconnected from Gateway: %s", self.agent_uri)

    async def send_result(self, task_id: str, result: Any):
        """Send a task result frame to the Gateway.

        Args:
            task_id: Identifier of the task the result belongs to.
            result: JSON-serializable result payload.

        Raises:
            Exception: If the client is not connected.
        """
        if not self.connected or self.ws is None:
            raise Exception("Not connected to Gateway")

        await self.ws.send(json.dumps({
            "type": "result",
            "task_id": task_id,
            "result": result
        }))

    async def _heartbeat_loop(self):
        """Send a ``heartbeat`` frame every 30 seconds while connected."""
        while self.connected:
            try:
                await asyncio.sleep(30)  # Heartbeat interval: 30 seconds.
                if not self.connected:
                    # disconnect() or a receive failure happened during the sleep.
                    break
                await self.ws.send(json.dumps({"type": "heartbeat"}))
                # TODO: wait for heartbeat_ack with a timeout.
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Heartbeat error: %s", e)
                # The link is unusable; stop reporting it as connected.
                self.connected = False
                break

    async def _receive_loop(self):
        """Receive frames from the Gateway and dispatch them until disconnected.

        A malformed frame or a failing callback is logged and skipped. Only a
        failure of the socket itself ends the loop.
        """
        while self.connected:
            try:
                data = await self.ws.recv()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Receive error: %s", e)
                # The socket is unusable; stop reporting it as connected.
                self.connected = False
                break

            try:
                msg = json.loads(data)
            except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
                logger.warning("Dropping malformed frame from Gateway: %s", e)
                continue
            if not isinstance(msg, dict):
                logger.warning("Dropping non-object frame from Gateway: %r", msg)
                continue

            await self._dispatch(msg)

    async def _dispatch(self, msg: Dict[str, Any]):
        """Route one decoded frame by its ``type`` field."""
        msg_type = msg.get("type")

        if msg_type == "heartbeat_ack":
            # Heartbeat acknowledgement; nothing to do.
            return
        if msg_type != "message":
            logger.warning("Unknown message type: %s", msg_type)
            return
        if not self.on_message:
            logger.warning("Received message but no handler: %s", msg)
            return

        try:
            outcome = self.on_message(msg)
            # Support both coroutine functions and plain callables.
            if inspect.isawaitable(outcome):
                await outcome
        except asyncio.CancelledError:
            raise
        except Exception:
            # A failing handler must not stop reception of later messages.
            logger.exception("on_message handler failed for message: %s", msg)


async def create_gateway_client(
    gateway_url: str,
    agent_uri: str,
    capabilities: Optional[list] = None,
    on_message: Optional[Callable] = None
) -> GatewayClient:
    """
    Create a Gateway client and connect it.

    Args:
        gateway_url: Gateway WebSocket URL
        agent_uri: URI of this Agent.
        capabilities: List of capabilities sent in the registration frame.
        on_message: Callback invoked with each received ``"message"`` frame.

    Returns:
        A GatewayClient on which connect() has completed successfully.

    Raises:
        Exception: Propagated from GatewayClient.connect on failure.
    """
    client = GatewayClient(
        gateway_url=gateway_url,
        agent_uri=agent_uri,
        capabilities=capabilities,
        on_message=on_message
    )
    await client.connect()
    return client