Browse Source

Refactor Message Type

StrayWarrior 3 tuần trước cách đây
mục cha
commit
496c0c4910
5 tập tin đã thay đổi với 76 bổ sung72 xóa
  1. 18 16
      agent_service.py
  2. 2 1
      dialogue_manager.py
  3. 22 10
      message.py
  4. 3 3
      message_queue_backend.py
  5. 31 42
      unit_test.py

+ 18 - 16
agent_service.py

@@ -66,12 +66,13 @@ class AgentService:
             apscheduler.triggers.cron.CronTrigger(**schedule_params)
         )
 
-    def _get_agent_instance(self, user_id: str) -> DialogueManager:
-        """获取用户Agent实例"""
-        if user_id not in self.agent_registry:
-            self.agent_registry[user_id] = DialogueManager(
-                user_id, self.user_manager)
-        return self.agent_registry[user_id]
+    def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
+        """获取Agent实例"""
+        agent_key = 'agent_{}_{}'.format(staff_id, user_id)
+        if agent_key not in self.agent_registry:
+            self.agent_registry[agent_key] = DialogueManager(
+                staff_id, user_id, self.user_manager)
+        return self.agent_registry[agent_key]
 
     def process_messages(self):
         """持续处理接收队列消息"""
@@ -101,11 +102,12 @@ class AgentService:
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
 
     def process_single_message(self, message: Message):
-        user_id = message.user_id
+        user_id = message.sender
+        staff_id = message.receiver
 
         # 获取用户信息和Agent实例
         user_profile = self.user_manager.get_user_profile(user_id)
-        agent = self._get_agent_instance(user_id)
+        agent = self._get_agent_instance(staff_id, user_id)
 
         # 更新对话状态
         logging.debug("process message: {}".format(message))
@@ -131,8 +133,8 @@ class AgentService:
         self.human_queue.produce(Message.build(
             MessageType.TEXT,
             origin_message.channel,
-            origin_message.staff_id,
-            origin_message.user_id,
+            origin_message.sender,
+            origin_message.receiver,
             "用户对话需人工介入,用户名:{}".format(user_id),
             int(time.time() * 1000)
         ))
@@ -140,7 +142,7 @@ class AgentService:
     def _check_initiative_conversations(self):
         """定时检查主动发起对话"""
         for user_id in self.user_manager.list_all_users():
-            agent = self._get_agent_instance(user_id)
+            agent = self._get_agent_instance('staff_id_0', user_id)
             should_initiate = agent.should_initiate_conversation()
 
             if should_initiate:
@@ -158,11 +160,11 @@ class AgentService:
 
         if response := agent.generate_response(chat_response):
             logging.warning("user: {}, response: {}".format(user_id, response))
-            self.send_queue.produce({
-                'user_id': user_id,
-                'type': MessageType.TEXT,
-                'text': response,
-            })
+            current_ts = int(time.time() * 1000)
+            self.send_queue.produce(
+                Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                              agent.staff_id, user_id, response, current_ts)
+            )
 
     def _call_chat_api(self, chat_config: Dict) -> str:
         if global_flags.DISABLE_LLM_API_CALL:

+ 2 - 1
dialogue_manager.py

@@ -53,7 +53,8 @@ class TimeContext(Enum):
         self.description = description
 
 class DialogueManager:
-    def __init__(self, user_id: str, user_manager: UserManager):
+    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager):
+        self.staff_id = staff_id
         self.user_id = user_id
         self.user_manager = user_manager
         self.current_state = DialogueState.GREETING

+ 22 - 10
message.py

@@ -9,12 +9,23 @@ from typing import Optional
 from pydantic import BaseModel
 
 class MessageType(Enum):
+    DEFAULT = (-1, "未分类的消息")
     TEXT = (1, "文本")
-    AUDIO = (2, "音频")
-    IMAGE = (3, "图片")
-    VIDEO = (4, "视频")
-    MINI_PROGRAM = (5, "小程序")
-    LINK = (6, "链接")
+    VOICE = (2, "语音")
+    GIF = (3, "GIF")
+    IMAGE_GW = (4, "个微图片")
+    IMAGE_QW = (5, "企微图片")
+    MINI_PROGRAM = (6, "小程序")
+    LINK = (7, "链接")
+    SHI_PIN_HAO = (8, "视频号")
+    NAME_CARD = (9, "名片")
+    POSITION = (10, "位置")
+    RED_PACKET = (11, "红包")
+    FILE_GW = (12, "个微文件")
+    FILE_QW = (13, "企微文件")
+    VIDEO_GW = (14, "个微视频")
+    VIDEO_QW = (15, "企微视频")
+    AGGREGATION_MSG = (16, "聚合消息")
 
     ACTIVE_TRIGGER = (101, "主动触发器")
     AGGREGATION_TRIGGER = (102, "消息聚合触发器")
@@ -43,19 +54,20 @@ class Message(BaseModel):
      id: int
      type: MessageType
      channel: MessageChannel
-     staff_id: Optional[str] = None
-     user_id: str
+     sender: Optional[str] = None
+     receiver: str
      content: Optional[str] = None
      timestamp: int
+     ref_msg_id: Optional[int] = None
 
      @staticmethod
-     def build(type, channel, staff_id, user_id, content, timestamp):
+     def build(type, channel, sender, receiver, content, timestamp):
          return Message(
              id=0,
              type=type,
              channel=channel,
-             staff_id=staff_id,
-             user_id=user_id,
+             sender=sender,
+             receiver=receiver,
              content=content,
              timestamp=timestamp
          )

+ 3 - 3
message_queue_backend.py

@@ -3,14 +3,14 @@
 # vim:fenc=utf-8
 
 import abc
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 
 from message import Message
 
 
 class MessageQueueBackend(abc.ABC):
     @abc.abstractmethod
-    def consume(self) -> Any:
+    def consume(self) -> Optional[Message]:
         pass
 
     @abc.abstractmethod
@@ -22,7 +22,7 @@ class MemoryQueueBackend(MessageQueueBackend):
     def __init__(self):
         self._queue = []
 
-    def consume(self):
+    def consume(self) -> Optional[Message]:
         return self._queue.pop(0) if self._queue else None
 
     def produce(self, message: Message):

+ 31 - 42
unit_test.py

@@ -7,7 +7,7 @@ from datetime import datetime, timedelta
 from typing import Dict, Optional, Tuple, Any
 from unittest.mock import Mock, MagicMock
 from agent_service import AgentService, MemoryQueueBackend
-from message import MessageType
+from message import MessageType, Message, MessageChannel
 from user_manager import LocalUserManager
 import time
 import logging
@@ -49,15 +49,12 @@ def test_env():
 def test_normal_conversation_flow(test_env):
     """测试正常对话流程"""
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备测试消息
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "你好",
-        "timestamp": int(time.time() * 1000),
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
     queues.receive_queue.produce(test_msg)
 
     # 处理消息
@@ -68,29 +65,23 @@ def test_normal_conversation_flow(test_env):
     # 验证响应消息
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
-    assert sent_msg["user_id"] == "user_id_0"
-    assert "模拟响应" in sent_msg["text"]
+    assert sent_msg.receiver == "user_id_0"
+    assert "模拟响应" in sent_msg.content
 
 def test_aggregated_conversation_flow(test_env):
     """测试聚合对话流程"""
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 1
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
 
     # 准备测试消息
     ts_begin = int(time.time() * 1000)
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "你好",
-        "timestamp": ts_begin,
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '你好', ts_begin)
     queues.receive_queue.produce(test_msg)
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "我是老李",
-        "timestamp": ts_begin + 0.5 * 1000,
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
     queues.receive_queue.produce(test_msg)
 
     # 处理消息
@@ -110,29 +101,27 @@ def test_aggregated_conversation_flow(test_env):
     assert sent_msg is None
 
     # 模拟定时器产生空消息触发响应
-    service.process_single_message({
-        "user_id": "user_id_0",
-        "type": MessageType.AGGREGATION_TRIGGER,
-        "timestamp": ts_begin + 2 * 1000
-    })
+    service.process_single_message(Message.build(
+        MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', None, ts_begin + 2000
+    ))
     # 验证第三次响应消息
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
-    assert sent_msg["user_id"] == "user_id_0"
-    assert "模拟响应" in sent_msg["text"]
+    assert sent_msg.receiver == "user_id_0"
+    assert "模拟响应" in sent_msg.content
 
 def test_human_intervention_trigger(test_env):
     """测试触发人工干预"""
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0',"user_id_0").message_aggregation_sec = 0
 
     # 准备需要人工干预的消息
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "我需要帮助!",
-        "timestamp": int(time.time() * 1000),
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        "user_id_0", "staff_id_0",
+        "我需要帮助!", int(time.time() * 1000)
+    )
     queues.receive_queue.produce(test_msg)
 
     # 处理消息
@@ -143,17 +132,17 @@ def test_human_intervention_trigger(test_env):
     # 验证人工队列消息
     human_msg = queues.human_queue.consume()
     assert human_msg is not None
-    assert human_msg["user_id"] == "user_id_0"
-    assert "state" in human_msg
+    assert human_msg.sender == "user_id_0"
+    assert "用户对话需人工介入" in human_msg.content
 
 def test_initiative_conversation(test_env):
     """测试主动发起对话"""
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
     service._call_chat_api = Mock(return_value="主动发起模拟消息")
 
     # 设置Agent需要主动发起对话
-    agent = service._get_agent_instance("user_id_0")
+    agent = service._get_agent_instance('staff_id_0', "user_id_0")
     agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
 
     service._check_initiative_conversations()
@@ -161,4 +150,4 @@ def test_initiative_conversation(test_env):
     # 验证主动发起的消息
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
-    assert "主动发起" in sent_msg["text"]
+    assert "主动发起" in sent_msg.content