router.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
"""
Gateway Router
Agent WebSocket registration, message sending, and online-status queries.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
from pydantic import BaseModel
from .registry import AgentRegistry
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SendMessageRequest(BaseModel):
    """Request body for ``POST /gateway/send``: a message to route to an agent."""
    to: str  # URI of the target agent
    content: Any  # Message content (a string or a multimodal array)
    conversation_id: Optional[str] = None  # Omit to have the gateway generate a new id
    metadata: Optional[Dict] = None  # Forwarded to the agent; sent as {} when omitted
class SendMessageResponse(BaseModel):
    """Response body for ``POST /gateway/send``."""
    message_id: str  # Gateway-generated id of the routed message
    conversation_id: str  # Caller-supplied or newly generated conversation id
    status: str  # "sent" (written to the agent's WebSocket) | "queued" (HTTP agents)
class AgentStatusResponse(BaseModel):
    """Response body for ``GET /gateway/status/{agent_uri}``."""
    agent_uri: str  # The agent URI that was queried
    status: str  # "online" | "offline"
    last_seen: Optional[str] = None  # ISO-8601 time of the last heartbeat; None for unknown agents
  29. class GatewayRouter:
  30. """Gateway 路由器"""
  31. def __init__(self, registry: AgentRegistry):
  32. self.registry = registry
  33. self.router = APIRouter(prefix="/gateway", tags=["gateway"])
  34. # 注册路由
  35. self.router.add_api_websocket_route("/connect", self.handle_websocket)
  36. self.router.add_api_route("/send", self.send_message, methods=["POST"])
  37. self.router.add_api_route("/status/{agent_uri:path}", self.get_agent_status, methods=["GET"])
  38. self.router.add_api_route("/agents", self.list_agents, methods=["GET"])
  39. async def handle_websocket(self, websocket: WebSocket):
  40. """
  41. 处理 WebSocket 连接
  42. Agent 通过此端点注册并保持长连接
  43. """
  44. await websocket.accept()
  45. agent_uri = None
  46. try:
  47. # 等待注册消息
  48. data = await websocket.receive_text()
  49. msg = json.loads(data)
  50. if msg.get("type") != "register":
  51. await websocket.send_text(json.dumps({
  52. "type": "error",
  53. "message": "First message must be register"
  54. }))
  55. await websocket.close()
  56. return
  57. agent_uri = msg.get("agent_uri")
  58. if not agent_uri:
  59. await websocket.send_text(json.dumps({
  60. "type": "error",
  61. "message": "agent_uri is required"
  62. }))
  63. await websocket.close()
  64. return
  65. # 注册 Agent
  66. await self.registry.register(
  67. agent_uri=agent_uri,
  68. connection_type="websocket",
  69. websocket=websocket,
  70. capabilities=msg.get("capabilities", []),
  71. metadata=msg.get("metadata", {})
  72. )
  73. # 发送注册成功消息
  74. await websocket.send_text(json.dumps({
  75. "type": "registered",
  76. "agent_uri": agent_uri
  77. }))
  78. logger.info(f"Agent connected: {agent_uri}")
  79. # 保持连接,处理消息
  80. while True:
  81. data = await websocket.receive_text()
  82. msg = json.loads(data)
  83. msg_type = msg.get("type")
  84. if msg_type == "heartbeat":
  85. # 更新心跳
  86. await self.registry.heartbeat(agent_uri)
  87. await websocket.send_text(json.dumps({
  88. "type": "heartbeat_ack"
  89. }))
  90. elif msg_type == "result":
  91. # Agent 返回任务结果
  92. # TODO: 将结果转发给调用方
  93. logger.info(f"Received result from {agent_uri}: {msg.get('task_id')}")
  94. else:
  95. logger.warning(f"Unknown message type: {msg_type}")
  96. except WebSocketDisconnect:
  97. logger.info(f"Agent disconnected: {agent_uri}")
  98. except Exception as e:
  99. logger.error(f"WebSocket error: {e}")
  100. finally:
  101. if agent_uri:
  102. await self.registry.unregister(agent_uri)
  103. async def send_message(self, request: SendMessageRequest) -> SendMessageResponse:
  104. """
  105. 发送消息到目标 Agent
  106. 通过 Gateway 路由消息
  107. """
  108. import uuid
  109. # 查找目标 Agent
  110. connection = self.registry.lookup(request.to)
  111. if not connection:
  112. raise HTTPException(status_code=404, detail=f"Agent not found: {request.to}")
  113. if not self.registry.is_online(request.to):
  114. raise HTTPException(status_code=503, detail=f"Agent offline: {request.to}")
  115. # 生成消息 ID
  116. message_id = f"msg-{uuid.uuid4()}"
  117. conversation_id = request.conversation_id or f"conv-{uuid.uuid4()}"
  118. # 构建消息
  119. message = {
  120. "type": "message",
  121. "message_id": message_id,
  122. "conversation_id": conversation_id,
  123. "from": "gateway", # TODO: 从请求中获取发送方
  124. "content": request.content,
  125. "metadata": request.metadata or {}
  126. }
  127. # 根据连接类型发送
  128. if connection.connection_type == "websocket":
  129. # 通过 WebSocket 发送
  130. await connection.websocket.send_text(json.dumps(message))
  131. status = "sent"
  132. elif connection.connection_type == "http":
  133. # 通过 HTTP 发送
  134. # TODO: 实现 HTTP 转发
  135. status = "queued"
  136. else:
  137. raise HTTPException(status_code=500, detail="Unknown connection type")
  138. return SendMessageResponse(
  139. message_id=message_id,
  140. conversation_id=conversation_id,
  141. status=status
  142. )
  143. async def get_agent_status(self, agent_uri: str) -> AgentStatusResponse:
  144. """查询 Agent 在线状态"""
  145. connection = self.registry.lookup(agent_uri)
  146. if not connection:
  147. return AgentStatusResponse(
  148. agent_uri=agent_uri,
  149. status="offline",
  150. last_seen=None
  151. )
  152. is_online = self.registry.is_online(agent_uri)
  153. return AgentStatusResponse(
  154. agent_uri=agent_uri,
  155. status="online" if is_online else "offline",
  156. last_seen=connection.last_heartbeat.isoformat() if connection else None
  157. )
  158. async def list_agents(
  159. self,
  160. connection_type: Optional[str] = None,
  161. online_only: bool = True
  162. ) -> Dict[str, Any]:
  163. """列出所有 Agent"""
  164. agents = self.registry.list_agents(
  165. connection_type=connection_type,
  166. online_only=online_only
  167. )
  168. return {
  169. "agents": [
  170. {
  171. "agent_uri": a.agent_uri,
  172. "connection_type": a.connection_type,
  173. "capabilities": a.capabilities,
  174. "registered_at": a.registered_at.isoformat(),
  175. "last_heartbeat": a.last_heartbeat.isoformat()
  176. }
  177. for a in agents
  178. ],
  179. "total": len(agents)
  180. }