| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- """
- Gateway Client
- Agent 端用于连接 Gateway 的客户端
- """
import asyncio
import inspect
import json
import logging
from typing import Optional, Dict, Any, Callable

import websockets
# Module-level logger; GatewayClient reports connection, heartbeat and receive events through it.
logger = logging.getLogger(__name__)
class GatewayClient:
    """WebSocket client an Agent uses to register with, and receive work from, a Gateway.

    Wire protocol as implemented here (one JSON object per frame):

    * on connect the client sends ``{"type": "register", "agent_uri", "capabilities"}``
      and requires ``{"type": "registered"}`` in reply;
    * every ``heartbeat_interval`` seconds it sends ``{"type": "heartbeat"}``;
    * incoming ``{"type": "message"}`` frames are passed to ``on_message``;
      ``{"type": "heartbeat_ack"}`` frames are ignored;
    * task results are sent as ``{"type": "result", "task_id", "result"}``.
    """

    def __init__(
        self,
        gateway_url: str,
        agent_uri: str,
        capabilities: Optional[list] = None,
        on_message: Optional[Callable] = None,
        heartbeat_interval: float = 30.0,
        register_timeout: Optional[float] = None,
    ):
        """
        Args:
            gateway_url: Gateway WebSocket URL (e.g. wss://gateway.com/gateway/connect).
            agent_uri: URI identifying this Agent.
            capabilities: Capabilities advertised at registration (defaults to ``[]``).
            on_message: Callback invoked with each decoded ``"message"`` frame (a dict).
                May be a coroutine function or a plain function.
            heartbeat_interval: Seconds between heartbeat frames (default 30).
            register_timeout: Seconds to wait for the registration reply; ``None``
                (the default) waits indefinitely.

        Raises:
            ValueError: if ``heartbeat_interval`` is not positive.
        """
        if heartbeat_interval <= 0:
            raise ValueError(f"heartbeat_interval must be positive, got {heartbeat_interval!r}")
        self.gateway_url = gateway_url
        self.agent_uri = agent_uri
        self.capabilities = capabilities or []
        self.on_message = on_message
        self.heartbeat_interval = heartbeat_interval
        self.register_timeout = register_timeout
        self.ws: Optional[websockets.WebSocketClientProtocol] = None
        self.connected = False
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._receive_task: Optional[asyncio.Task] = None

    async def connect(self):
        """Open the WebSocket, register this Agent and start the background loops.

        Only a ``{"type": "registered"}`` reply to the register frame counts as success.

        Raises:
            Exception: ``"Registration failed: ..."`` when the reply is anything else.
                Errors from ``websockets.connect``/``send``/``recv``, from
                ``json.loads`` and (when ``register_timeout`` is set)
                ``asyncio.TimeoutError`` propagate unchanged.

        If the attempt fails after the socket was opened, that socket is closed
        before the exception propagates, so failed attempts do not leak connections.
        """
        ws = None
        registered = False
        try:
            ws = await websockets.connect(self.gateway_url)
            # Announce ourselves; the Gateway must acknowledge before we proceed.
            await ws.send(json.dumps({
                "type": "register",
                "agent_uri": self.agent_uri,
                "capabilities": self.capabilities,
            }))
            response = await asyncio.wait_for(ws.recv(), self.register_timeout)
            msg = json.loads(response)
            if not isinstance(msg, dict) or msg.get("type") != "registered":
                raise Exception(f"Registration failed: {msg}")
            registered = True
        except Exception as e:
            logger.error(f"Failed to connect to Gateway: {e}")
            raise
        finally:
            # Also runs on cancellation, which `except Exception` may not see.
            if not registered and ws is not None:
                await self._close_quietly(ws)

        self.ws = ws
        self.connected = True
        logger.info(f"Connected to Gateway: {self.agent_uri}")
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def disconnect(self):
        """Stop the background loops and close the WebSocket.

        This is safe to call more than once. It is also safe from inside
        ``on_message``. The task running this method is never cancelled or
        awaited by it; that task sees ``connected == False`` and exits on its own.
        """
        self.connected = False
        current = asyncio.current_task()
        tasks = [
            task
            for task in (self._heartbeat_task, self._receive_task)
            if task is not None and task is not current
        ]
        self._heartbeat_task = None
        self._receive_task = None
        for task in tasks:
            task.cancel()
        if tasks:
            # Wait for the loops to actually finish so nothing touches the socket after close.
            await asyncio.gather(*tasks, return_exceptions=True)

        ws, self.ws = self.ws, None
        if ws is not None:
            await self._close_quietly(ws)
        logger.info(f"Disconnected from Gateway: {self.agent_uri}")

    async def send_result(self, task_id: str, result: Any):
        """Send the result of task ``task_id`` to the Gateway.

        ``result`` must be JSON-serialisable (it is passed to ``json.dumps``).

        Raises:
            Exception: ``"Not connected to Gateway"`` if the client is not connected.
        """
        if not self.connected or self.ws is None:
            raise Exception("Not connected to Gateway")
        await self.ws.send(json.dumps({
            "type": "result",
            "task_id": task_id,
            "result": result
        }))

    @staticmethod
    async def _close_quietly(ws) -> None:
        """Close ``ws``, logging rather than raising so cleanup never masks the original error."""
        try:
            await ws.close()
        except Exception as e:
            logger.warning(f"Error while closing Gateway connection: {e}")

    async def _heartbeat_loop(self):
        """Send a heartbeat frame every ``heartbeat_interval`` seconds while connected.

        A send failure is treated as a lost connection. It is logged, the
        client is marked disconnected, and the loop exits.
        """
        # TODO: heartbeat_ack replies are not tracked, so a half-open connection
        # that still accepts writes is not detected (no ack timeout yet).
        while self.connected:
            await asyncio.sleep(self.heartbeat_interval)
            if not self.connected:
                # Disconnected while sleeping; don't write to a closed/None socket.
                break
            try:
                await self.ws.send(json.dumps({"type": "heartbeat"}))
            except asyncio.CancelledError:
                raise  # on Python 3.7 CancelledError is an Exception subclass
            except Exception as e:
                logger.error(f"Heartbeat error: {e}")
                self.connected = False
                break

    async def _receive_loop(self):
        """Read frames from the Gateway until disconnected or the socket fails.

        A transport error ends the loop and marks the client disconnected, so
        ``send_result`` and the heartbeat loop stop using the dead socket.
        Problems confined to a single frame are logged and skipped; see
        ``_handle_frame``.
        """
        try:
            while self.connected:
                try:
                    data = await self.ws.recv()
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.error(f"Receive error: {e}")
                    break
                await self._handle_frame(data)
        finally:
            self.connected = False

    async def _handle_frame(self, data) -> None:
        """Decode one frame and dispatch it by its ``"type"`` field."""
        try:
            msg = json.loads(data)
        except (TypeError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            logger.error(f"Invalid JSON from Gateway: {e}")
            return
        if not isinstance(msg, dict):
            logger.warning(f"Unexpected non-object frame from Gateway: {msg!r}")
            return

        msg_type = msg.get("type")
        if msg_type == "heartbeat_ack":
            return  # acknowledgement of our heartbeat; nothing to do
        if msg_type != "message":
            logger.warning(f"Unknown message type: {msg_type}")
            return
        if self.on_message is None:
            logger.warning(f"Received message but no handler: {msg}")
            return
        try:
            # Support both coroutine functions and plain callables.
            outcome = self.on_message(msg)
            if inspect.isawaitable(outcome):
                await outcome
        except asyncio.CancelledError:
            raise
        except Exception:
            # A failing handler must not take down the connection's receive loop.
            logger.exception("on_message handler failed")
async def create_gateway_client(
    gateway_url: str,
    agent_uri: str,
    capabilities: Optional[list] = None,
    on_message: Optional[Callable] = None
) -> GatewayClient:
    """Build a GatewayClient and connect it before handing it back.

    Args:
        gateway_url: Gateway WebSocket URL.
        agent_uri: URI identifying this Agent.
        capabilities: Capabilities advertised at registration.
        on_message: Callback invoked for incoming messages.

    Returns:
        A GatewayClient on which ``connect()`` has already completed.

    Raises:
        Whatever ``GatewayClient.connect()`` raises; it is not caught here.
    """
    gateway_client = GatewayClient(
        agent_uri=agent_uri,
        gateway_url=gateway_url,
        on_message=on_message,
        capabilities=capabilities,
    )
    await gateway_client.connect()
    return gateway_client
|