client.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. """
  2. Gateway Client
  3. Agent 端用于连接 Gateway 的客户端
  4. """
import asyncio
import inspect
import json
import logging
from typing import Optional, Dict, Any, Callable

import websockets
  10. logger = logging.getLogger(__name__)
  11. class GatewayClient:
  12. """Gateway 客户端"""
  13. def __init__(
  14. self,
  15. gateway_url: str,
  16. agent_uri: str,
  17. capabilities: Optional[list] = None,
  18. on_message: Optional[Callable] = None
  19. ):
  20. """
  21. Args:
  22. gateway_url: Gateway WebSocket URL (wss://gateway.com/gateway/connect)
  23. agent_uri: 本 Agent 的 URI
  24. capabilities: Agent 能力列表
  25. on_message: 收到消息时的回调函数
  26. """
  27. self.gateway_url = gateway_url
  28. self.agent_uri = agent_uri
  29. self.capabilities = capabilities or []
  30. self.on_message = on_message
  31. self.ws: Optional[websockets.WebSocketClientProtocol] = None
  32. self.connected = False
  33. self._heartbeat_task = None
  34. self._receive_task = None
  35. async def connect(self):
  36. """连接到 Gateway"""
  37. try:
  38. self.ws = await websockets.connect(self.gateway_url)
  39. # 发送注册消息
  40. await self.ws.send(json.dumps({
  41. "type": "register",
  42. "agent_uri": self.agent_uri,
  43. "capabilities": self.capabilities
  44. }))
  45. # 等待注册确认
  46. response = await self.ws.recv()
  47. msg = json.loads(response)
  48. if msg.get("type") == "registered":
  49. self.connected = True
  50. logger.info(f"Connected to Gateway: {self.agent_uri}")
  51. # 启动心跳和接收任务
  52. self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
  53. self._receive_task = asyncio.create_task(self._receive_loop())
  54. else:
  55. raise Exception(f"Registration failed: {msg}")
  56. except Exception as e:
  57. logger.error(f"Failed to connect to Gateway: {e}")
  58. raise
  59. async def disconnect(self):
  60. """断开连接"""
  61. self.connected = False
  62. if self._heartbeat_task:
  63. self._heartbeat_task.cancel()
  64. if self._receive_task:
  65. self._receive_task.cancel()
  66. if self.ws:
  67. await self.ws.close()
  68. logger.info(f"Disconnected from Gateway: {self.agent_uri}")
  69. async def send_result(self, task_id: str, result: Any):
  70. """发送任务结果"""
  71. if not self.connected:
  72. raise Exception("Not connected to Gateway")
  73. await self.ws.send(json.dumps({
  74. "type": "result",
  75. "task_id": task_id,
  76. "result": result
  77. }))
  78. async def _heartbeat_loop(self):
  79. """心跳循环"""
  80. while self.connected:
  81. try:
  82. await asyncio.sleep(30) # 每 30 秒发送心跳
  83. await self.ws.send(json.dumps({"type": "heartbeat"}))
  84. # 等待心跳确认
  85. # TODO: 添加超时处理
  86. except asyncio.CancelledError:
  87. break
  88. except Exception as e:
  89. logger.error(f"Heartbeat error: {e}")
  90. break
  91. async def _receive_loop(self):
  92. """接收消息循环"""
  93. while self.connected:
  94. try:
  95. data = await self.ws.recv()
  96. msg = json.loads(data)
  97. msg_type = msg.get("type")
  98. if msg_type == "heartbeat_ack":
  99. # 心跳确认,忽略
  100. pass
  101. elif msg_type == "message":
  102. # 收到消息,调用回调
  103. if self.on_message:
  104. await self.on_message(msg)
  105. else:
  106. logger.warning(f"Received message but no handler: {msg}")
  107. else:
  108. logger.warning(f"Unknown message type: {msg_type}")
  109. except asyncio.CancelledError:
  110. break
  111. except Exception as e:
  112. logger.error(f"Receive error: {e}")
  113. break
  114. async def create_gateway_client(
  115. gateway_url: str,
  116. agent_uri: str,
  117. capabilities: Optional[list] = None,
  118. on_message: Optional[Callable] = None
  119. ) -> GatewayClient:
  120. """
  121. 创建并连接 Gateway 客户端
  122. Args:
  123. gateway_url: Gateway WebSocket URL
  124. agent_uri: 本 Agent 的 URI
  125. capabilities: Agent 能力列表
  126. on_message: 收到消息时的回调函数
  127. Returns:
  128. 已连接的 GatewayClient
  129. """
  130. client = GatewayClient(
  131. gateway_url=gateway_url,
  132. agent_uri=agent_uri,
  133. capabilities=capabilities,
  134. on_message=on_message
  135. )
  136. await client.connect()
  137. return client