tools.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. """IM Agent Tools — 供 Agent 在 tool-use loop 中调用的工具函数。
  2. 新架构:一个 Agent (contact_id) = 一个 IMClient 实例,该实例管理多个窗口 (chat_id)。
  3. 使用方式:
  4. 1. 调用 setup(contact_id) 初始化 Agent 的 IMClient
  5. 2. 调用 open_window(contact_id, chat_id) 打开窗口
  6. 3. 在每轮 loop 中调用 check_notification(contact_id, chat_id) 检查该窗口的新消息
  7. 4. 有通知时调用 receive_messages(contact_id, chat_id) 读取消息
  8. 5. 发消息调用 send_message(contact_id, chat_id, receiver, content)
  9. """
  10. import asyncio
  11. import httpx
  12. from client import IMClient
  13. from notifier import AgentNotifier
  14. # ── 全局状态 ──
  15. _clients: dict[str, IMClient] = {}
  16. _tasks: dict[str, asyncio.Task] = {}
  17. _notifications: dict[tuple[str, str], dict] = {} # (contact_id, chat_id) -> 通知
  18. class _ToolNotifier(AgentNotifier):
  19. """内部通知器:按 (contact_id, chat_id) 分发通知。"""
  20. def __init__(self, contact_id: str, chat_id: str):
  21. self._key = (contact_id, chat_id)
  22. async def notify(self, count: int, from_contacts: list[str]):
  23. _notifications[self._key] = {"count": count, "from": from_contacts}
  24. # ── Tool 1: 初始化 Agent ──
  25. def setup(contact_id: str, server_url: str = "ws://localhost:8000", notify_interval: float = 10.0) -> str:
  26. """初始化一个 Agent 的 IMClient(一个实例管理多个窗口)。
  27. Args:
  28. contact_id: Agent 的身份 ID
  29. server_url: Server 地址
  30. notify_interval: 检查新消息的间隔秒数
  31. Returns:
  32. 状态描述
  33. """
  34. if contact_id in _clients:
  35. return f"已连接: {contact_id}"
  36. client = IMClient(contact_id=contact_id, server_url=server_url, notify_interval=notify_interval)
  37. _clients[contact_id] = client
  38. loop = asyncio.get_event_loop()
  39. _tasks[contact_id] = loop.create_task(client.run())
  40. return f"已启动 IM Client: {contact_id}"
  41. def teardown(contact_id: str) -> str:
  42. """停止并移除一个 Agent 的 IMClient。"""
  43. task = _tasks.pop(contact_id, None)
  44. if task:
  45. task.cancel()
  46. _clients.pop(contact_id, None)
  47. # 清理该 contact_id 的所有通知
  48. keys_to_remove = [k for k in _notifications if k[0] == contact_id]
  49. for k in keys_to_remove:
  50. _notifications.pop(k, None)
  51. return f"已停止: {contact_id}"
  52. # ── Tool 2: 窗口管理 ──
  53. def open_window(contact_id: str, chat_id: str | None = None) -> str:
  54. """为某个 Agent 打开一个新窗口。
  55. Args:
  56. contact_id: Agent ID
  57. chat_id: 窗口 ID(留空自动生成)
  58. Returns:
  59. 窗口的 chat_id
  60. """
  61. client = _clients.get(contact_id)
  62. if client is None:
  63. return f"错误: {contact_id} 未初始化"
  64. actual_chat_id = client.open_window(chat_id=chat_id, notifier=_ToolNotifier(contact_id, chat_id or ""))
  65. # 更新 notifier 的 chat_id
  66. if chat_id is None:
  67. client._notifiers[actual_chat_id] = _ToolNotifier(contact_id, actual_chat_id)
  68. return actual_chat_id
  69. def close_window(contact_id: str, chat_id: str) -> str:
  70. """关闭某个窗口。"""
  71. client = _clients.get(contact_id)
  72. if client is None:
  73. return f"错误: {contact_id} 未初始化"
  74. client.close_window(chat_id)
  75. _notifications.pop((contact_id, chat_id), None)
  76. return f"已关闭窗口: {chat_id}"
  77. def list_windows(contact_id: str) -> list[str]:
  78. """列出某个 Agent 的所有窗口。"""
  79. client = _clients.get(contact_id)
  80. if client is None:
  81. return []
  82. return client.list_windows()
  83. # ── Tool 3: 检查通知 ──
  84. def check_notification(contact_id: str, chat_id: str) -> dict | None:
  85. """检查某个窗口是否有新消息通知。
  86. Returns:
  87. 有新消息: {"count": 3, "from": ["alice", "bob"]}
  88. 没有新消息: None
  89. """
  90. return _notifications.pop((contact_id, chat_id), None)
  91. # ── Tool 4: 接收消息 ──
  92. def receive_messages(contact_id: str, chat_id: str) -> list[dict]:
  93. """读取某个窗口的待处理消息,读取后自动清空。
  94. Returns:
  95. 消息列表,每条格式:
  96. {
  97. "sender": "alice",
  98. "sender_chat_id": "...",
  99. "content": "你好",
  100. "msg_type": "chat"
  101. }
  102. """
  103. client = _clients.get(contact_id)
  104. if client is None:
  105. return []
  106. raw = client.read_pending(chat_id)
  107. return [
  108. {
  109. "sender": m.get("sender", "unknown"),
  110. "sender_chat_id": m.get("sender_chat_id"),
  111. "content": m.get("content", ""),
  112. "msg_type": m.get("msg_type", "chat"),
  113. }
  114. for m in raw
  115. ]
  116. # ── Tool 5: 发送消息 ──
  117. def send_message(
  118. contact_id: str,
  119. chat_id: str,
  120. receiver: str,
  121. content: str,
  122. msg_type: str = "chat",
  123. receiver_chat_id: str | None = None,
  124. ) -> str:
  125. """从某个窗口发送消息。
  126. Args:
  127. contact_id: 发送方 Agent ID
  128. chat_id: 发送方窗口 ID
  129. receiver: 接收方 contact_id
  130. content: 消息内容
  131. msg_type: 消息类型
  132. receiver_chat_id: 接收方窗口 ID(不指定则广播)
  133. Returns:
  134. 状态描述
  135. """
  136. client = _clients.get(contact_id)
  137. if client is None:
  138. return f"错误: {contact_id} 未初始化"
  139. client.send_message(chat_id, receiver, content, msg_type, receiver_chat_id)
  140. target = f"{receiver}:{receiver_chat_id}" if receiver_chat_id else f"{receiver}:*"
  141. return f"[{contact_id}:{chat_id}] 已发送给 {target}: {content[:50]}"
  142. # ── Tool 6: 查询联系人 ──
  143. def get_contacts(contact_id: str, server_http_url: str = "http://localhost:8000") -> dict:
  144. """查询某个 Agent 的联系人列表和在线用户。"""
  145. if contact_id not in _clients:
  146. return {"error": f"{contact_id} 未初始化"}
  147. result = {}
  148. with httpx.Client() as http:
  149. try:
  150. r = http.get(f"{server_http_url}/contacts/{contact_id}")
  151. result["contacts"] = r.json().get("contacts", [])
  152. except Exception as e:
  153. result["contacts_error"] = str(e)
  154. try:
  155. r = http.get(f"{server_http_url}/health")
  156. result["online"] = r.json().get("online", {})
  157. except Exception as e:
  158. result["online_error"] = str(e)
  159. return result
  160. # ── Tool 7: 查询聊天历史 ──
  161. def get_chat_history(contact_id: str, chat_id: str, peer_id: str | None = None, limit: int = 20) -> list[dict]:
  162. """查询某个窗口的聊天历史。"""
  163. client = _clients.get(contact_id)
  164. if client is None:
  165. return []
  166. return client.get_chat_history(chat_id, peer_id, limit)