client.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import asyncio
  2. import json
  3. import logging
  4. import os
  5. import tempfile
  6. import uuid
  7. from datetime import datetime
  8. from pathlib import Path
  9. import websockets
  10. from filelock import FileLock
  11. from protocol import IMMessage, IMResponse
  12. from notifier import AgentNotifier, ConsoleNotifier
  # NOTE: configures the root logger as a side effect of importing this module.
  logging.basicConfig(
      level=logging.INFO,
      format="%(asctime)s [CLIENT:%(name)s] %(message)s",
  )
  class ChatWindow:
      """Message storage for one chat window.

      A window owns a directory with three files:

      * ``chatbox.jsonl``: full history, one JSON object per line (append-only).
      * ``in_pending.json``: JSON array of received messages not yet read.
      * ``out_pending.jsonl``: messages whose send failed, one per line.

      Each file is guarded by its own ``FileLock`` (a hidden ``.<name>.lock``
      file in the same directory), so access to different files never contends.
      """

      def __init__(self, chat_id: str, data_dir: Path):
          self.chat_id = chat_id
          self.data_dir = data_dir
          self.data_dir.mkdir(parents=True, exist_ok=True)
          self.chatbox_path = data_dir / "chatbox.jsonl"
          self.in_pending_path = data_dir / "in_pending.json"
          self.out_pending_path = data_dir / "out_pending.jsonl"
          # File locks, one per data file.
          self._in_pending_lock = FileLock(str(data_dir / ".in_pending.lock"))
          self._out_pending_lock = FileLock(str(data_dir / ".out_pending.lock"))
          self._chatbox_lock = FileLock(str(data_dir / ".chatbox.lock"))
          # Create missing files while holding their lock: an unlocked
          # exists()/write pair can overwrite data that another holder of the
          # lock wrote between the check and the write.
          self._init_file(self._chatbox_lock, self.chatbox_path, "")
          self._init_file(self._in_pending_lock, self.in_pending_path, "[]")
          self._init_file(self._out_pending_lock, self.out_pending_path, "")

      def append_to_in_pending(self, msg: dict):
          """Append ``msg`` to the in_pending array (read-modify-write under lock)."""
          with self._in_pending_lock:
              pending = self._load_json_array(self.in_pending_path)
              pending.append(msg)
              self._atomic_write_json(self.in_pending_path, pending)

      def read_in_pending(self) -> list[dict]:
          """Return the in_pending messages without removing them."""
          with self._in_pending_lock:
              return self._load_json_array(self.in_pending_path)

      def clear_in_pending(self):
          """Replace the in_pending array with an empty one."""
          with self._in_pending_lock:
              self._atomic_write_json(self.in_pending_path, [])

      def append_to_chatbox(self, msg: dict):
          """Append ``msg`` as one JSON line to the chat history."""
          self._append_jsonl(self._chatbox_lock, self.chatbox_path, msg)

      def append_to_out_pending(self, msg: dict):
          """Append ``msg`` as one JSON line to the failed-send file."""
          self._append_jsonl(self._out_pending_lock, self.out_pending_path, msg)

      @staticmethod
      def _init_file(lock, path: Path, initial: str):
          """Create ``path`` containing ``initial`` unless it already exists."""
          with lock:
              if not path.exists():
                  path.write_text(initial, encoding="utf-8")

      @staticmethod
      def _append_jsonl(lock, path: Path, msg: dict):
          """Append ``msg`` as a single JSON line to ``path`` while holding ``lock``."""
          line = json.dumps(msg, ensure_ascii=False) + "\n"
          with lock:
              with open(path, "a", encoding="utf-8") as f:
                  f.write(line)

      @staticmethod
      def _load_json_array(path: Path) -> list:
          """Load a JSON array from ``path``.

          Returns ``[]`` when the file is missing or blank. Content that is not
          UTF-8 JSON, or is JSON but not an array, also yields ``[]``. That case is
          logged, because a following read-modify-write (e.g.
          ``append_to_in_pending``) will overwrite the file.
          """
          if not path.exists():
              return []
          try:
              text = path.read_text(encoding="utf-8").strip()
              if not text:
                  return []
              data = json.loads(text)
          except ValueError:  # JSONDecodeError and UnicodeDecodeError
              logging.getLogger(__name__).warning(f"{path} 不是有效的 JSON, 按空列表处理")
              return []
          if not isinstance(data, list):
              logging.getLogger(__name__).warning(f"{path} 不是 JSON 数组, 按空列表处理")
              return []
          return data

      @staticmethod
      def _atomic_write_json(path: Path, data):
          """Atomically replace ``path`` with ``data`` serialized as JSON.

          The JSON goes to a temporary file in the same directory, which is then
          flushed, fsync'ed and moved over ``path`` with ``os.replace``. Readers
          therefore see either the old or the new content, never a partial file.
          The temporary file is removed if anything fails.
          """
          tmp_fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
          try:
              with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
                  json.dump(data, f, ensure_ascii=False, indent=2)
                  # Without fsync, a crash shortly after os.replace() can leave
                  # an empty or truncated file where the old content used to be.
                  f.flush()
                  os.fsync(f.fileno())
              os.replace(tmp_path, str(path))
          except BaseException:
              if os.path.exists(tmp_path):
                  os.unlink(tmp_path)
              raise
  76. class IMClient:
  77. """IM Client - 一个实例管理多个聊天窗口。
  78. 一个 Agent (contact_id) 对应一个 IMClient 实例。
  79. 该实例可以管理多个 chat_id(窗口),每个窗口有独立的消息存储。
  80. """
  def __init__(
      self,
      contact_id: str,
      server_url: str = "ws://localhost:8000",
      data_dir: str | None = None,
      notify_interval: float = 30.0,
  ):
      """Set up the client for agent ``contact_id``; no connection is opened here.

      Args:
          contact_id: this agent's id. It is also the logger name and, by
              default, the storage sub-directory name.
          server_url: base WebSocket URL; ``run`` appends ``/ws`` and a query.
          data_dir: storage root; defaults to ``data/<contact_id>``.
          notify_interval: seconds between pending-message notification passes.
      """
      self.contact_id = contact_id
      self.server_url = server_url
      self.notify_interval = notify_interval
      if data_dir:
          self.base_dir = Path(data_dir)
      else:
          self.base_dir = Path("data") / contact_id
      self.base_dir.mkdir(parents=True, exist_ok=True)
      self.log = logging.getLogger(contact_id)
      # Open windows and their notifiers, both keyed by chat_id.
      self._windows: dict[str, ChatWindow] = {}
      self._notifiers: dict[str, AgentNotifier] = {}
      # Set by run() once connected; reset to None on disconnect.
      self.ws = None
      # Filled by send_message(), drained by _send_worker().
      self._send_queue = asyncio.Queue()
  def open_window(self, chat_id: str | None = None, notifier: AgentNotifier | None = None) -> str:
      """Open a chat window (or reuse an open one) and return its chat_id.

      Args:
          chat_id: window id. When None, an id of the form
              ``YYYYmmdd_HHMMSS_<6 hex chars>`` is generated.
          notifier: notifier for pending messages in this window. Defaults to a
              ``ConsoleNotifier``. It is ignored if the window is already open.

      Returns:
          The window's chat_id.
      """
      if chat_id is None:
          stamp = datetime.now().strftime("%Y%m%d_%H%M%S_")
          chat_id = f"{stamp}{uuid.uuid4().hex[:6]}"
      if chat_id not in self._windows:
          self._windows[chat_id] = ChatWindow(chat_id, self.base_dir / "windows" / chat_id)
          self._notifiers[chat_id] = notifier or ConsoleNotifier()
          self.log.info(f"打开窗口: {chat_id}")
      return chat_id
  def close_window(self, chat_id: str):
      """Forget window ``chat_id`` and its notifier; a no-op if it is not open.

      The window's files on disk are left untouched.
      """
      for registry in (self._windows, self._notifiers):
          registry.pop(chat_id, None)
      self.log.info(f"关闭窗口: {chat_id}")
  def list_windows(self) -> list[str]:
      """Return the chat_ids of all open windows, in insertion order."""
      return [*self._windows]
  async def run(self):
      """Run the client, reconnecting 5 seconds after every disconnect.

      Each connection runs the listener, the send worker and the pending
      notifier through ``_serve_connection``. These failures are logged and
      followed by a reconnect: a refused connection, a closed socket, a
      rejected handshake, or an open timeout. A clean server-side close is
      handled the same way. Cancelling ``run()`` while it is connecting or
      connected stops it and returns. During the 5 s back-off, the
      CancelledError propagates instead.
      """
      while True:
          try:
              # No single chat_id on connect: this instance serves many windows.
              ws_url = f"{self.server_url}/ws?contact_id={self.contact_id}&chat_id=__multi__"
              self.log.info(f"连接 {ws_url} ...")
              async with websockets.connect(ws_url) as ws:
                  self.ws = ws
                  self.log.info("已连接")
                  await self._serve_connection()
              # The listener ends without raising when the server closes the
              # socket normally, so reconnect explicitly in that case too.
              self.log.warning("连接已关闭, 5 秒后重连...")
          except (
              websockets.ConnectionClosed,
              websockets.InvalidHandshake,
              ConnectionRefusedError,
              OSError,
              asyncio.TimeoutError,
          ) as e:
              self.log.warning(f"连接断开: {e}, 5 秒后重连...")
          except asyncio.CancelledError:
              self.log.info("服务停止")
              break
          finally:
              self.ws = None
          await asyncio.sleep(5)

  async def _serve_connection(self):
      """Run the per-connection loops until the first one ends, then stop the rest.

      ``asyncio.gather`` would leave the surviving loops running after one
      fails, so each reconnect would add another send worker and notifier.
      Here every loop is cancelled and awaited before returning. The first
      loop's exception, if any (e.g. ConnectionClosed), is re-raised to
      ``run``.
      """
      tasks = [
          asyncio.create_task(self._ws_listener()),
          asyncio.create_task(self._send_worker()),
          asyncio.create_task(self._pending_notifier()),
      ]
      try:
          done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
      finally:
          for task in tasks:
              task.cancel()
          # Let cancellation finish so no loop outlives this connection.
          await asyncio.gather(*tasks, return_exceptions=True)
      for task in done:
          exc = task.exception()
          if exc is not None:
              raise exc
  async def _ws_listener(self):
      """Read frames from the WebSocket and dispatch them until the socket closes.

      A JSON object with both ``sender`` and ``receiver`` is a chat message (see
      ``_handle_chat_message``). An object with ``status`` is a send receipt
      (see ``_handle_receipt``). A malformed frame, or an error while handling
      one, is logged and skipped. It no longer ends the listener and drops the
      connection.
      """
      async for raw in self.ws:
          try:
              data = json.loads(raw)
          except ValueError:  # JSONDecodeError, or undecodable bytes
              self.log.warning(f"收到无效 JSON: {raw}")
              continue
          if not isinstance(data, dict):
              self.log.warning(f"收到无效 JSON: {raw}")
              continue
          try:
              if "sender" in data and "receiver" in data:
                  self._handle_chat_message(data)
              elif "status" in data:
                  self._handle_receipt(data)
          except Exception as e:
              # Per-frame boundary: log with traceback and keep listening.
              self.log.exception(f"处理消息失败: {e}")

  def _handle_chat_message(self, data: dict):
      """Store a chat message in the window named by ``receiver_chat_id``, or in all windows.

      With no ``receiver_chat_id``, the message goes to every open window. With
      an id that names no open window, it is dropped with a warning.
      """
      receiver_chat_id = data.get("receiver_chat_id")
      if not receiver_chat_id:
          # Broadcast to all windows.
          for window in self._windows.values():
              window.append_to_in_pending(data)
              window.append_to_chatbox(data)
          self.log.info(f"收到消息 -> 广播到 {len(self._windows)} 个窗口: {data['sender']}")
          return
      window = self._windows.get(receiver_chat_id)
      if window is None:
          self.log.warning(f"收到消息但窗口 {receiver_chat_id} 不存在")
          return
      window.append_to_in_pending(data)
      window.append_to_chatbox(data)
      self.log.info(f"收到消息 -> 窗口 {receiver_chat_id}: {data['sender']}")

  def _handle_receipt(self, data: dict):
      """Log a send receipt returned by the server."""
      resp = IMResponse(**data)
      if resp.status == "success":
          self.log.info(f"消息 {resp.msg_id} 发送成功")
      else:
          self.log.warning(f"消息 {resp.msg_id} 发送失败: {resp.error}")
  async def _send_worker(self):
      """Send queued messages forever and record each one in its sender window.

      Each item from ``_send_queue`` becomes an ``IMMessage`` whose sender is
      this client. Items that fail to build (TypeError/ValueError) are logged
      and dropped rather than raising out of the worker. If the send fails,
      the message is appended to the sending window's out_pending file. If it
      succeeds, the message is appended to that window's chatbox. A local
      write error at that point is only logged, so a delivered message is
      never also queued as unsent.
      """
      while True:
          msg_data = await self._send_queue.get()
          try:
              msg = IMMessage(sender=self.contact_id, **msg_data)
          except (TypeError, ValueError) as e:
              self.log.error(f"消息格式无效, 已丢弃: {e}")
              continue
          try:
              await self.ws.send(msg.model_dump_json())
          except Exception as e:
              self.log.error(f"发送失败: {e}")
              window = self._sender_window(msg)
              if window is not None:
                  window.append_to_out_pending(msg.model_dump())
              continue
          self.log.info(f"发送消息: -> {msg.receiver}:{msg.receiver_chat_id or '*'}")
          window = self._sender_window(msg)
          if window is not None:
              try:
                  window.append_to_chatbox(msg.model_dump())
              except OSError as e:
                  self.log.error(f"消息已发送, 但写入聊天记录失败: {e}")

  def _sender_window(self, msg):
      """Return the open window named by ``msg.sender_chat_id``, or None."""
      if msg.sender_chat_id and msg.sender_chat_id in self._windows:
          return self._windows[msg.sender_chat_id]
      return None
  async def _pending_notifier(self):
      """Every ``notify_interval`` seconds, notify each window that has pending messages.

      Messages stay in in_pending until ``read_pending`` clears them, so a
      window is notified again on each pass until then. Each sender is
      reported once, in order of first appearance. A read or callback error
      for one window is logged and does not affect the other windows.
      """
      while True:
          for chat_id, window in list(self._windows.items()):
              try:
                  pending = window.read_in_pending()
              except OSError as e:
                  self.log.error(f"窗口 {chat_id} 读取待处理消息失败: {e}")
                  continue
              if not pending:
                  continue
              # dict.fromkeys de-duplicates while keeping first-seen order;
              # set() iteration order varies from run to run (hash seeding).
              senders = list(dict.fromkeys(
                  m.get("sender", "unknown") if isinstance(m, dict) else "unknown"
                  for m in pending
              ))
              notifier = self._notifiers.get(chat_id)
              if notifier:
                  try:
                      await notifier.notify(count=len(pending), from_contacts=senders)
                  except Exception as e:
                      self.log.error(f"窗口 {chat_id} 通知回调异常: {e}")
          await asyncio.sleep(self.notify_interval)
  208. # ── Agent 调用的工具方法 ──
  def read_pending(self, chat_id: str) -> list[dict]:
      """Return the pending messages of window ``chat_id`` and clear them.

      Reading and clearing happen under a single hold of the window's
      in_pending lock. With two separate lock acquisitions, a message
      appended in between (by another thread or process using the same file)
      would be erased without ever being returned.

      Returns ``[]`` when the window is not open or nothing is pending.
      """
      window = self._windows.get(chat_id)
      if window is None:
          return []
      with window._in_pending_lock:
          pending = window._load_json_array(window.in_pending_path)
          if pending:
              window._atomic_write_json(window.in_pending_path, [])
      return pending
  def send_message(
      self,
      chat_id: str,
      receiver: str,
      content: str,
      msg_type: str = "chat",
      receiver_chat_id: str | None = None,
  ):
      """Queue a message from window ``chat_id`` to contact ``receiver``.

      This only enqueues and never blocks. ``_send_worker`` sends the message
      while a connection is up. With ``receiver_chat_id=None``, the receiving
      client presumably stores it in all of its open windows (that is what
      ``_ws_listener`` does for messages without a target window).
      """
      self._send_queue.put_nowait(
          dict(
              sender_chat_id=chat_id,
              receiver=receiver,
              content=content,
              msg_type=msg_type,
              receiver_chat_id=receiver_chat_id,
          )
      )
  def get_chat_history(self, chat_id: str, peer_id: str | None = None, limit: int = 20) -> list[dict]:
      """Return up to ``limit`` of the most recent messages in a window's chatbox.

      Args:
          chat_id: window whose chatbox.jsonl is read.
          peer_id: if given, keep only messages sent by or to this contact.
          limit: maximum number of messages; ``limit <= 0`` returns ``[]``.

      Returns:
          Messages oldest-first. Each is reduced to the ``sender``, ``receiver``,
          ``content`` and ``msg_type`` keys. Blank lines, unparsable lines and
          lines that are not JSON objects are skipped.
      """
      window = self._windows.get(chat_id)
      # The limit check must come first: inside the loop it only runs after
      # an append, so limit=0 would still return one message.
      if limit <= 0 or window is None or not window.chatbox_path.exists():
          return []
      lines = window.chatbox_path.read_text(encoding="utf-8").strip().splitlines()
      messages = []
      # Walk newest-first so we can stop once `limit` matches are collected.
      for line in reversed(lines):
          if not line.strip():
              continue
          try:
              m = json.loads(line)
          except json.JSONDecodeError:
              continue
          if not isinstance(m, dict):
              continue
          if peer_id and peer_id not in (m.get("sender"), m.get("receiver")):
              continue
          messages.append({
              "sender": m.get("sender", "unknown"),
              "receiver": m.get("receiver", "unknown"),
              "content": m.get("content", ""),
              "msg_type": m.get("msg_type", "chat"),
          })
          if len(messages) >= limit:
              break
      messages.reverse()
      return messages