| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- """
- Gateway Router
- 消息路由、在线状态查询
- """
- import json
- import logging
- from typing import Dict, Any, Optional
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
- from pydantic import BaseModel
- from .registry import AgentRegistry
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class SendMessageRequest(BaseModel):
    """Body of ``POST /gateway/send``: one message addressed to one agent."""

    # URI of the agent that should receive the message.
    to: str
    # Message payload: a string or a multimodal array (any value is accepted).
    content: Any
    # Conversation to attach to; when omitted the gateway generates a new id.
    conversation_id: Optional[str] = None
    # Free-form extra data forwarded together with the content.
    metadata: Optional[Dict] = None
class SendMessageResponse(BaseModel):
    """Reply returned by ``POST /gateway/send``."""

    # Identifier the gateway assigned to the message.
    message_id: str
    # Conversation the message belongs to (supplied by the caller or generated).
    conversation_id: str
    # Delivery outcome: "sent" or "queued".
    status: str
class AgentStatusResponse(BaseModel):
    """Reply returned by ``GET /gateway/status/{agent_uri}``."""

    # The agent URI that was queried.
    agent_uri: str
    # Either "online" or "offline".
    status: str
    # ISO-8601 timestamp of the agent's last heartbeat, if known.
    last_seen: Optional[str] = None
class GatewayRouter:
    """Gateway router: message routing and agent online-status queries.

    Builds an ``APIRouter`` mounted under ``/gateway`` exposing:

    * ``WS   /gateway/connect``            -- agents register and keep a long-lived socket
    * ``POST /gateway/send``               -- deliver a message to a registered agent
    * ``GET  /gateway/status/{agent_uri}`` -- online/offline status of one agent
    * ``GET  /gateway/agents``             -- list registered agents

    All agent state is kept in the injected ``AgentRegistry``.
    """

    def __init__(self, registry: AgentRegistry):
        """Create the router and bind every endpoint to it.

        Args:
            registry: Registry used to register, look up and unregister agents.
        """
        self.registry = registry
        self.router = APIRouter(prefix="/gateway", tags=["gateway"])
        self.router.add_api_websocket_route("/connect", self.handle_websocket)
        self.router.add_api_route("/send", self.send_message, methods=["POST"])
        # The ``:path`` converter lets agent URIs contain "/" characters.
        self.router.add_api_route(
            "/status/{agent_uri:path}", self.get_agent_status, methods=["GET"]
        )
        self.router.add_api_route("/agents", self.list_agents, methods=["GET"])

    # ------------------------------------------------------------------
    # WebSocket helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_frame(data: str) -> Optional[Dict[str, Any]]:
        """Decode one text frame; return it if it is a JSON object, else None."""
        try:
            msg = json.loads(data)
        except json.JSONDecodeError:
            return None
        return msg if isinstance(msg, dict) else None

    @staticmethod
    async def _send_json(websocket: WebSocket, payload: Dict[str, Any]) -> None:
        """Serialize *payload* as JSON and send it as one text frame."""
        await websocket.send_text(json.dumps(payload))

    async def _reject(self, websocket: WebSocket, reason: str) -> None:
        """Send an ``error`` frame carrying *reason*, then close the socket."""
        await self._send_json(websocket, {"type": "error", "message": reason})
        await websocket.close()

    async def _handle_agent_frame(
        self, websocket: WebSocket, agent_uri: str, msg: Dict[str, Any]
    ) -> None:
        """Act on one post-registration frame received from *agent_uri*."""
        msg_type = msg.get("type")
        if msg_type == "heartbeat":
            await self.registry.heartbeat(agent_uri)
            await self._send_json(websocket, {"type": "heartbeat_ack"})
        elif msg_type == "result":
            # The agent returned a task result.
            # TODO: forward the result to the original caller.
            logger.info("Received result from %s: %s", agent_uri, msg.get("task_id"))
        else:
            logger.warning("Unknown message type: %s", msg_type)

    async def _unregister_if_current(self, agent_uri: str, websocket: WebSocket) -> None:
        """Unregister *agent_uri* unless the registry maps it to another connection.

        Prevents the cleanup of a stale socket (or of a socket whose
        registration failed) from evicting a different connection that is
        registered under the same URI.
        """
        connection = self.registry.lookup(agent_uri)
        # NOTE(review): assumes ``connection.websocket`` is the object that was
        # passed to ``register(websocket=...)``; entries without that attribute
        # are treated as belonging to another connection and left alone.
        if connection and getattr(connection, "websocket", None) is not websocket:
            return
        await self.registry.unregister(agent_uri)

    # ------------------------------------------------------------------
    # Endpoints
    # ------------------------------------------------------------------

    async def handle_websocket(self, websocket: WebSocket):
        """Serve one agent's WebSocket connection.

        Protocol:

        1. The first frame must be a JSON object with ``"type": "register"``
           and a non-empty ``"agent_uri"``; optional ``"capabilities"`` and
           ``"metadata"`` are passed to the registry. Otherwise an ``error``
           frame is sent and the socket is closed.
        2. On success a ``registered`` frame is sent back.
        3. After that, ``heartbeat`` frames are acknowledged, ``result`` frames
           are logged, and unknown types are logged as warnings. A frame that
           is not a JSON object gets an ``error`` reply without dropping the
           connection.

        When the connection ends for any reason, the agent is unregistered,
        unless the registry now maps its URI to a different connection.
        """
        await websocket.accept()
        agent_uri = None
        try:
            msg = self._parse_frame(await websocket.receive_text())
            if msg is None or msg.get("type") != "register":
                await self._reject(websocket, "First message must be register")
                return
            agent_uri = msg.get("agent_uri")
            if not agent_uri:
                await self._reject(websocket, "agent_uri is required")
                return

            await self.registry.register(
                agent_uri=agent_uri,
                connection_type="websocket",
                websocket=websocket,
                capabilities=msg.get("capabilities", []),
                metadata=msg.get("metadata", {}),
            )
            await self._send_json(
                websocket, {"type": "registered", "agent_uri": agent_uri}
            )
            logger.info("Agent connected: %s", agent_uri)

            while True:
                msg = self._parse_frame(await websocket.receive_text())
                if msg is None:
                    # One malformed frame should not tear down the whole session.
                    logger.warning("Malformed frame from %s", agent_uri)
                    await self._send_json(
                        websocket,
                        {"type": "error", "message": "Message must be a JSON object"},
                    )
                    continue
                await self._handle_agent_frame(websocket, agent_uri, msg)
        except WebSocketDisconnect:
            logger.info("Agent disconnected: %s", agent_uri)
        except Exception:
            # Connection boundary: log with traceback; cleanup happens below.
            logger.exception("WebSocket error (agent=%s)", agent_uri)
        finally:
            if agent_uri:
                await self._unregister_if_current(agent_uri, websocket)

    async def send_message(self, request: SendMessageRequest) -> SendMessageResponse:
        """Route a message to the target agent through the gateway.

        Args:
            request: Target URI, content and optional conversation/metadata.

        Returns:
            The generated message id, the conversation id (supplied or newly
            generated), and ``status`` of ``"sent"`` (pushed over WebSocket)
            or ``"queued"`` (HTTP agents; forwarding not yet implemented).

        Raises:
            HTTPException: 404 if the agent is unknown; 503 if it is offline or
                the WebSocket push fails; 500 for an unknown connection type.
        """
        import uuid

        connection = self.registry.lookup(request.to)
        if not connection:
            raise HTTPException(status_code=404, detail=f"Agent not found: {request.to}")
        if not self.registry.is_online(request.to):
            raise HTTPException(status_code=503, detail=f"Agent offline: {request.to}")

        message_id = f"msg-{uuid.uuid4()}"
        conversation_id = request.conversation_id or f"conv-{uuid.uuid4()}"
        message = {
            "type": "message",
            "message_id": message_id,
            "conversation_id": conversation_id,
            "from": "gateway",  # TODO: take the sender from the request
            "content": request.content,
            "metadata": request.metadata or {},
        }

        if connection.connection_type == "websocket":
            try:
                await connection.websocket.send_text(json.dumps(message))
            except Exception as err:
                # The socket may have died between is_online() and the push;
                # report it as unavailable instead of an opaque 500.
                logger.exception("Failed to deliver %s to %s", message_id, request.to)
                raise HTTPException(
                    status_code=503, detail=f"Agent unreachable: {request.to}"
                ) from err
            status = "sent"
        elif connection.connection_type == "http":
            # TODO: implement HTTP forwarding.
            status = "queued"
        else:
            raise HTTPException(status_code=500, detail="Unknown connection type")

        return SendMessageResponse(
            message_id=message_id,
            conversation_id=conversation_id,
            status=status,
        )

    async def get_agent_status(self, agent_uri: str) -> AgentStatusResponse:
        """Report whether *agent_uri* is online and when it was last seen.

        Unknown agents are reported as ``offline`` with ``last_seen=None``.
        """
        connection = self.registry.lookup(agent_uri)
        if not connection:
            return AgentStatusResponse(agent_uri=agent_uri, status="offline", last_seen=None)
        return AgentStatusResponse(
            agent_uri=agent_uri,
            status="online" if self.registry.is_online(agent_uri) else "offline",
            last_seen=connection.last_heartbeat.isoformat(),
        )

    async def list_agents(
        self,
        connection_type: Optional[str] = None,
        online_only: bool = True
    ) -> Dict[str, Any]:
        """List registered agents.

        Args:
            connection_type: Optional filter passed through to the registry.
            online_only: Passed through to the registry; defaults to True.

        Returns:
            ``{"agents": [...], "total": n}``; each agent entry carries its URI,
            connection type, capabilities and ISO-8601 timestamps.
        """
        agents = self.registry.list_agents(
            connection_type=connection_type,
            online_only=online_only
        )
        return {
            "agents": [
                {
                    "agent_uri": a.agent_uri,
                    "connection_type": a.connection_type,
                    "capabilities": a.capabilities,
                    "registered_at": a.registered_at.isoformat(),
                    "last_heartbeat": a.last_heartbeat.isoformat(),
                }
                for a in agents
            ],
            "total": len(agents),
        }
|