瀏覽代碼

Refactor Message Type

StrayWarrior 3 周之前
父節點
當前提交
496c0c4910
共有 5 個文件被更改,包括 76 次插入72 次删除
  1. 18 16
      agent_service.py
  2. 2 1
      dialogue_manager.py
  3. 22 10
      message.py
  4. 3 3
      message_queue_backend.py
  5. 31 42
      unit_test.py

+ 18 - 16
agent_service.py

@@ -66,12 +66,13 @@ class AgentService:
             apscheduler.triggers.cron.CronTrigger(**schedule_params)
             apscheduler.triggers.cron.CronTrigger(**schedule_params)
         )
         )
 
 
-    def _get_agent_instance(self, user_id: str) -> DialogueManager:
-        """获取用户Agent实例"""
-        if user_id not in self.agent_registry:
-            self.agent_registry[user_id] = DialogueManager(
-                user_id, self.user_manager)
-        return self.agent_registry[user_id]
+    def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
+        """获取Agent实例"""
+        agent_key = 'agent_{}_{}'.format(staff_id, user_id)
+        if agent_key not in self.agent_registry:
+            self.agent_registry[agent_key] = DialogueManager(
+                staff_id, user_id, self.user_manager)
+        return self.agent_registry[agent_key]
 
 
     def process_messages(self):
     def process_messages(self):
         """持续处理接收队列消息"""
         """持续处理接收队列消息"""
@@ -101,11 +102,12 @@ class AgentService:
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
 
 
     def process_single_message(self, message: Message):
     def process_single_message(self, message: Message):
-        user_id = message.user_id
+        user_id = message.sender
+        staff_id = message.receiver
 
 
         # 获取用户信息和Agent实例
         # 获取用户信息和Agent实例
         user_profile = self.user_manager.get_user_profile(user_id)
         user_profile = self.user_manager.get_user_profile(user_id)
-        agent = self._get_agent_instance(user_id)
+        agent = self._get_agent_instance(staff_id, user_id)
 
 
         # 更新对话状态
         # 更新对话状态
         logging.debug("process message: {}".format(message))
         logging.debug("process message: {}".format(message))
@@ -131,8 +133,8 @@ class AgentService:
         self.human_queue.produce(Message.build(
         self.human_queue.produce(Message.build(
             MessageType.TEXT,
             MessageType.TEXT,
             origin_message.channel,
             origin_message.channel,
-            origin_message.staff_id,
-            origin_message.user_id,
+            origin_message.sender,
+            origin_message.receiver,
             "用户对话需人工介入,用户名:{}".format(user_id),
             "用户对话需人工介入,用户名:{}".format(user_id),
             int(time.time() * 1000)
             int(time.time() * 1000)
         ))
         ))
@@ -140,7 +142,7 @@ class AgentService:
     def _check_initiative_conversations(self):
     def _check_initiative_conversations(self):
         """定时检查主动发起对话"""
         """定时检查主动发起对话"""
         for user_id in self.user_manager.list_all_users():
         for user_id in self.user_manager.list_all_users():
-            agent = self._get_agent_instance(user_id)
+            agent = self._get_agent_instance('staff_id_0', user_id)
             should_initiate = agent.should_initiate_conversation()
             should_initiate = agent.should_initiate_conversation()
 
 
             if should_initiate:
             if should_initiate:
@@ -158,11 +160,11 @@ class AgentService:
 
 
         if response := agent.generate_response(chat_response):
         if response := agent.generate_response(chat_response):
             logging.warning("user: {}, response: {}".format(user_id, response))
             logging.warning("user: {}, response: {}".format(user_id, response))
-            self.send_queue.produce({
-                'user_id': user_id,
-                'type': MessageType.TEXT,
-                'text': response,
-            })
+            current_ts = int(time.time() * 1000)
+            self.send_queue.produce(
+                Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                              agent.staff_id, user_id, response, current_ts)
+            )
 
 
     def _call_chat_api(self, chat_config: Dict) -> str:
     def _call_chat_api(self, chat_config: Dict) -> str:
         if global_flags.DISABLE_LLM_API_CALL:
         if global_flags.DISABLE_LLM_API_CALL:

+ 2 - 1
dialogue_manager.py

@@ -53,7 +53,8 @@ class TimeContext(Enum):
         self.description = description
         self.description = description
 
 
 class DialogueManager:
 class DialogueManager:
-    def __init__(self, user_id: str, user_manager: UserManager):
+    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager):
+        self.staff_id = staff_id
         self.user_id = user_id
         self.user_id = user_id
         self.user_manager = user_manager
         self.user_manager = user_manager
         self.current_state = DialogueState.GREETING
         self.current_state = DialogueState.GREETING

+ 22 - 10
message.py

@@ -9,12 +9,23 @@ from typing import Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
 class MessageType(Enum):
 class MessageType(Enum):
+    DEFAULT = (-1, "未分类的消息")
     TEXT = (1, "文本")
     TEXT = (1, "文本")
-    AUDIO = (2, "音频")
-    IMAGE = (3, "图片")
-    VIDEO = (4, "视频")
-    MINI_PROGRAM = (5, "小程序")
-    LINK = (6, "链接")
+    VOICE = (2, "语音")
+    GIF = (3, "GIF")
+    IMAGE_GW = (4, "个微图片")
+    IMAGE_QW = (5, "企微图片")
+    MINI_PROGRAM = (6, "小程序")
+    LINK = (7, "链接")
+    SHI_PIN_HAO = (8, "视频号")
+    NAME_CARD = (9, "名片")
+    POSITION = (10, "位置")
+    RED_PACKET = (11, "红包")
+    FILE_GW = (12, "个微文件")
+    FILE_QW = (13, "企微文件")
+    VIDEO_GW = (14, "个微视频")
+    VIDEO_QW = (15, "企微视频")
+    AGGREGATION_MSG = (16, "聚合消息")
 
 
     ACTIVE_TRIGGER = (101, "主动触发器")
     ACTIVE_TRIGGER = (101, "主动触发器")
     AGGREGATION_TRIGGER = (102, "消息聚合触发器")
     AGGREGATION_TRIGGER = (102, "消息聚合触发器")
@@ -43,19 +54,20 @@ class Message(BaseModel):
      id: int
      id: int
      type: MessageType
      type: MessageType
      channel: MessageChannel
      channel: MessageChannel
-     staff_id: Optional[str] = None
-     user_id: str
+     sender: Optional[str] = None
+     receiver: str
      content: Optional[str] = None
      content: Optional[str] = None
      timestamp: int
      timestamp: int
+     ref_msg_id: Optional[int] = None
 
 
      @staticmethod
      @staticmethod
-     def build(type, channel, staff_id, user_id, content, timestamp):
+     def build(type, channel, sender, receiver, content, timestamp):
          return Message(
          return Message(
              id=0,
              id=0,
              type=type,
              type=type,
              channel=channel,
              channel=channel,
-             staff_id=staff_id,
-             user_id=user_id,
+             sender=sender,
+             receiver=receiver,
              content=content,
              content=content,
              timestamp=timestamp
              timestamp=timestamp
          )
          )

+ 3 - 3
message_queue_backend.py

@@ -3,14 +3,14 @@
 # vim:fenc=utf-8
 # vim:fenc=utf-8
 
 
 import abc
 import abc
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 
 
 from message import Message
 from message import Message
 
 
 
 
 class MessageQueueBackend(abc.ABC):
 class MessageQueueBackend(abc.ABC):
     @abc.abstractmethod
     @abc.abstractmethod
-    def consume(self) -> Any:
+    def consume(self) -> Optional[Message]:
         pass
         pass
 
 
     @abc.abstractmethod
     @abc.abstractmethod
@@ -22,7 +22,7 @@ class MemoryQueueBackend(MessageQueueBackend):
     def __init__(self):
     def __init__(self):
         self._queue = []
         self._queue = []
 
 
-    def consume(self):
+    def consume(self) -> Optional[Message]:
         return self._queue.pop(0) if self._queue else None
         return self._queue.pop(0) if self._queue else None
 
 
     def produce(self, message: Message):
     def produce(self, message: Message):

+ 31 - 42
unit_test.py

@@ -7,7 +7,7 @@ from datetime import datetime, timedelta
 from typing import Dict, Optional, Tuple, Any
 from typing import Dict, Optional, Tuple, Any
 from unittest.mock import Mock, MagicMock
 from unittest.mock import Mock, MagicMock
 from agent_service import AgentService, MemoryQueueBackend
 from agent_service import AgentService, MemoryQueueBackend
-from message import MessageType
+from message import MessageType, Message, MessageChannel
 from user_manager import LocalUserManager
 from user_manager import LocalUserManager
 import time
 import time
 import logging
 import logging
@@ -49,15 +49,12 @@ def test_env():
 def test_normal_conversation_flow(test_env):
 def test_normal_conversation_flow(test_env):
     """测试正常对话流程"""
     """测试正常对话流程"""
     service, queues = test_env
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
 
     # 准备测试消息
     # 准备测试消息
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "你好",
-        "timestamp": int(time.time() * 1000),
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
     queues.receive_queue.produce(test_msg)
     queues.receive_queue.produce(test_msg)
 
 
     # 处理消息
     # 处理消息
@@ -68,29 +65,23 @@ def test_normal_conversation_flow(test_env):
     # 验证响应消息
     # 验证响应消息
     sent_msg = queues.send_queue.consume()
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
     assert sent_msg is not None
-    assert sent_msg["user_id"] == "user_id_0"
-    assert "模拟响应" in sent_msg["text"]
+    assert sent_msg.receiver == "user_id_0"
+    assert "模拟响应" in sent_msg.content
 
 
 def test_aggregated_conversation_flow(test_env):
 def test_aggregated_conversation_flow(test_env):
     """测试聚合对话流程"""
     """测试聚合对话流程"""
     service, queues = test_env
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 1
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
 
 
     # 准备测试消息
     # 准备测试消息
     ts_begin = int(time.time() * 1000)
     ts_begin = int(time.time() * 1000)
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "你好",
-        "timestamp": ts_begin,
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '你好', ts_begin)
     queues.receive_queue.produce(test_msg)
     queues.receive_queue.produce(test_msg)
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "我是老李",
-        "timestamp": ts_begin + 0.5 * 1000,
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
     queues.receive_queue.produce(test_msg)
     queues.receive_queue.produce(test_msg)
 
 
     # 处理消息
     # 处理消息
@@ -110,29 +101,27 @@ def test_aggregated_conversation_flow(test_env):
     assert sent_msg is None
     assert sent_msg is None
 
 
     # 模拟定时器产生空消息触发响应
     # 模拟定时器产生空消息触发响应
-    service.process_single_message({
-        "user_id": "user_id_0",
-        "type": MessageType.AGGREGATION_TRIGGER,
-        "timestamp": ts_begin + 2 * 1000
-    })
+    service.process_single_message(Message.build(
+        MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
+        'user_id_0', 'staff_id_0', None, ts_begin + 2000
+    ))
     # 验证第三次响应消息
     # 验证第三次响应消息
     sent_msg = queues.send_queue.consume()
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
     assert sent_msg is not None
-    assert sent_msg["user_id"] == "user_id_0"
-    assert "模拟响应" in sent_msg["text"]
+    assert sent_msg.receiver == "user_id_0"
+    assert "模拟响应" in sent_msg.content
 
 
 def test_human_intervention_trigger(test_env):
 def test_human_intervention_trigger(test_env):
     """测试触发人工干预"""
     """测试触发人工干预"""
     service, queues = test_env
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0',"user_id_0").message_aggregation_sec = 0
 
 
     # 准备需要人工干预的消息
     # 准备需要人工干预的消息
-    test_msg = {
-        "user_id": "user_id_0",
-        "type": MessageType.TEXT,
-        "text": "我需要帮助!",
-        "timestamp": int(time.time() * 1000),
-    }
+    test_msg = Message.build(
+        MessageType.TEXT, MessageChannel.CORP_WECHAT,
+        "user_id_0", "staff_id_0",
+        "我需要帮助!", int(time.time() * 1000)
+    )
     queues.receive_queue.produce(test_msg)
     queues.receive_queue.produce(test_msg)
 
 
     # 处理消息
     # 处理消息
@@ -143,17 +132,17 @@ def test_human_intervention_trigger(test_env):
     # 验证人工队列消息
     # 验证人工队列消息
     human_msg = queues.human_queue.consume()
     human_msg = queues.human_queue.consume()
     assert human_msg is not None
     assert human_msg is not None
-    assert human_msg["user_id"] == "user_id_0"
-    assert "state" in human_msg
+    assert human_msg.sender == "user_id_0"
+    assert "用户对话需人工介入" in human_msg.content
 
 
 def test_initiative_conversation(test_env):
 def test_initiative_conversation(test_env):
     """测试主动发起对话"""
     """测试主动发起对话"""
     service, queues = test_env
     service, queues = test_env
-    service._get_agent_instance("user_id_0").message_aggregation_sec = 0
+    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
     service._call_chat_api = Mock(return_value="主动发起模拟消息")
     service._call_chat_api = Mock(return_value="主动发起模拟消息")
 
 
     # 设置Agent需要主动发起对话
     # 设置Agent需要主动发起对话
-    agent = service._get_agent_instance("user_id_0")
+    agent = service._get_agent_instance('staff_id_0', "user_id_0")
     agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
     agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
 
 
     service._check_initiative_conversations()
     service._check_initiative_conversations()
@@ -161,4 +150,4 @@ def test_initiative_conversation(test_env):
     # 验证主动发起的消息
     # 验证主动发起的消息
     sent_msg = queues.send_queue.consume()
     sent_msg = queues.send_queue.consume()
     assert sent_msg is not None
     assert sent_msg is not None
-    assert "主动发起" in sent_msg["text"]
+    assert "主动发起" in sent_msg.content