Просмотр исходного кода

Refactor: rename Message to MqMessage

StrayWarrior 2 недель назад
Родитель
Сommit
2343486149

+ 9 - 9
pqai_agent/agent_service.py

@@ -31,7 +31,7 @@ from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
-from pqai_agent.mq_message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.utils.db_utils import create_sql_engine
 
 
@@ -136,7 +136,7 @@ class AgentService:
             time.sleep(1)
         logger.info("Scheduler event processing thread exit")
 
-    def process_scheduler_event(self, msg: Message):
+    def process_scheduler_event(self, msg: MqMessage):
         if msg.type == MessageType.AGGREGATION_TRIGGER:
             # 延迟触发的消息,需放入接收队列以驱动Agent运转
             self.receive_queue.produce(msg)
@@ -190,7 +190,7 @@ class AgentService:
                     logger.error("Error processing message: {}, {}".format(e, error_stack))
             time.sleep(0.1)
         receive_queue.shutdown()
-        logger.info("Message processing thread exit")
+        logger.info("MqMessage processing thread exit")
 
     def start(self, blocking=False):
         self.running = True
@@ -255,7 +255,7 @@ class AgentService:
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
         logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
-        msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
+        msg = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         # 系统消息使用特定的msgId,无实际意义
         msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
         if self.scheduler_mode == 'mq':
@@ -265,7 +265,7 @@ class AgentService:
                                    'date',
                                    run_date=datetime.now() + timedelta(seconds=delay_sec))
 
-    def process_single_message(self, message: Message):
+    def process_single_message(self, message: MqMessage):
         user_id = message.sender
         staff_id = message.receiver
 
@@ -325,13 +325,13 @@ class AgentService:
             return
         self.send_rate_limiter.wait_for_sending(staff_id, response)
         self.send_queue.produce(
-            Message.build(message_type, MessageChannel.CORP_WECHAT,
-                          staff_id, user_id, response, current_ts)
+            MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
+                            staff_id, user_id, response, current_ts)
         )
 
-    def _route_to_human_intervention(self, user_id: str, origin_message: Message):
+    def _route_to_human_intervention(self, user_id: str, origin_message: MqMessage):
         """路由到人工干预"""
-        self.human_queue.produce(Message.build(
+        self.human_queue.produce(MqMessage.build(
             MessageType.TEXT,
             origin_message.channel,
             origin_message.sender,

+ 2 - 2
pqai_agent/dialogue_manager.py

@@ -21,7 +21,7 @@ from pqai_agent import chat_service, prompt_templates
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 
 from pqai_agent.chat_service import ChatServiceType
-from pqai_agent.mq_message import MessageType, Message
+from pqai_agent.mq_message import MessageType, MqMessage
 from pqai_agent.toolkit.lark_alert_for_human_intervention import LarkAlertForHumanIntervention
 from pqai_agent.toolkit.lark_sheet_record_for_human_intervention import LarkSheetRecordForHumanIntervention
 from pqai_agent.user_manager import UserManager
@@ -239,7 +239,7 @@ class DialogueManager:
             (self.current_state, self.previous_state)
         ))
 
-    def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
+    def update_state(self, message: MqMessage) -> Tuple[bool, Optional[str]]:
         """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
         message_text = message.content
         message_ts = message.sendTime

+ 15 - 15
pqai_agent/message_queue_backend.py

@@ -11,21 +11,21 @@ from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 from pqai_agent import configs
 from pqai_agent import logging_service
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import Message, MessageType, MessageChannel
+from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 
 
 
 class MessageQueueBackend(abc.ABC):
     @abc.abstractmethod
-    def consume(self) -> Optional[Message]:
+    def consume(self) -> Optional[MqMessage]:
         pass
 
     @abc.abstractmethod
-    def ack(self, message: Message) -> None:
+    def ack(self, message: MqMessage) -> None:
         pass
 
     @abc.abstractmethod
-    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
         pass
 
     @abc.abstractmethod
@@ -37,13 +37,13 @@ class MemoryQueueBackend(MessageQueueBackend):
     def __init__(self):
         self._queue = []
 
-    def consume(self) -> Optional[Message]:
+    def consume(self) -> Optional[MqMessage]:
         return self._queue.pop(0) if self._queue else None
 
-    def ack(self, message: Message):
+    def ack(self, message: MqMessage):
         return
 
-    def produce(self, message: Message, msg_group: Optional[str] = None):
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None):
         self._queue.append(message)
 
     def shutdown(self):
@@ -78,7 +78,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
     def __del__(self):
         self.shutdown()
 
-    def consume(self, invisible_duration=60) -> Optional[Message]:
+    def consume(self, invisible_duration=60) -> Optional[MqMessage]:
         if not self.has_consumer:
             raise Exception("Consumer not initialized.")
         # TODO(zhoutian): invisible_duration实际是不同消息类型不同的,有些消息预期的处理时间会更长
@@ -89,7 +89,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
         body = rmq_message.body.decode('utf-8')
         logger.debug("[{}]recv message body: {}".format(self.topic, body))
         try:
-            message = Message.from_json(body)
+            message = MqMessage.from_json(body)
             message._rmq_message = rmq_message
         except Exception as e:
             logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
@@ -98,13 +98,13 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
         return message
 
-    def ack(self, message: Message):
+    def ack(self, message: MqMessage):
         if not message._rmq_message:
-            raise ValueError("Message not set with _rmq_message.")
+            raise ValueError("MqMessage not set with _rmq_message.")
         logger.debug("[{}]ack message: {}".format(self.topic, message))
         self.consumer.ack(message._rmq_message)
 
-    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
         if not self.has_producer:
             raise Exception("Producer not initialized.")
         message.model_config['use_enum_values'] = False
@@ -148,16 +148,16 @@ if __name__ == '__main__':
         else:
             break
 
-    send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
+    send_message = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
                                  "user_id_1", "staff_id_0",
-                                 None, int(time.time() * 1000))
+                                   None, int(time.time() * 1000))
     queue.produce(send_message)
     recv_message = queue.consume()
     print(recv_message)
     if recv_message:
         queue.ack(recv_message)
 
-    send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+    send_message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                  "user_id_1", "staff_id_0",
                                  "message_queue_backend test", int(time.time() * 1000))
     queue.produce(send_message)

+ 3 - 3
pqai_agent/mq_message.py

@@ -85,7 +85,7 @@ class MessageChannel(int, Enum):
             101: "系统内部"
         }[code]
 
-class Message(BaseModel):
+class MqMessage(BaseModel):
      msgId: Optional[int] = None
      type: MessageType
      channel: MessageChannel
@@ -102,7 +102,7 @@ class Message(BaseModel):
 
      @staticmethod
      def build(type, channel, sender, receiver, content, timestamp):
-         return Message(
+         return MqMessage(
              msgId=0,
              type=type,
              channel=channel,
@@ -120,4 +120,4 @@ class Message(BaseModel):
 
      @staticmethod
      def from_json(json_str):
-         return Message.model_validate_json(json_str)
+         return MqMessage.model_validate_json(json_str)

+ 5 - 5
pqai_agent_server/agent_server.py

@@ -6,7 +6,7 @@ from pqai_agent import configs, logging_service
 from pqai_agent.agent_service import AgentService
 from pqai_agent.chat_service import ChatServiceType
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
 from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
 from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
@@ -93,14 +93,14 @@ if __name__ == "__main__":
         receiver = '1688855931724582'
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
+            message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
         else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                                    sender,receiver, text, int(time.time() * 1000)
-                                    )
+            message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                      sender, receiver, text, int(time.time() * 1000)
+                                      )
         message.msgId = message_id
         receive_queue.produce(message)
         time.sleep(0.1)

+ 5 - 5
scripts/mq_sender.py

@@ -1,6 +1,6 @@
 import sys
 from pqai_agent import configs
-from pqai_agent.mq_message import Message, MessageType, MessageChannel
+from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend
 import time
 from argparse import ArgumentParser
@@ -33,13 +33,13 @@ if __name__ == '__main__':
         receiver = args.staff_id
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
+            message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
         else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                                    sender,receiver, text, int(time.time() * 1000)
-                                    )
+            message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                      sender, receiver, text, int(time.time() * 1000)
+                                      )
         message.msgId = message_id
         receive_queue.produce(message)

+ 3 - 3
scripts/resend_lost_message.py

@@ -5,7 +5,7 @@
 from datetime import datetime
 import re
 from pqai_agent import configs
-from pqai_agent.mq_message import MessageChannel, Message, MessageType
+from pqai_agent.mq_message import MessageChannel, MqMessage, MessageType
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend
 from pqai_agent.user_manager import MySQLUserRelationManager
 
@@ -102,8 +102,8 @@ def main():
                     # Check if user has already been processed
                     if user_id in processed_users:
                         break
-                    message = Message.build(message_type, MessageChannel.CORP_WECHAT,
-                                            staff_id, user_id, response, int(timestamp.timestamp() * 1000))
+                    message = MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
+                                              staff_id, user_id, response, int(timestamp.timestamp() * 1000))
                     print(message)
                     # Send the message
                     send_queue.produce(message)

+ 6 - 6
tests/unit_test.py

@@ -6,7 +6,7 @@ import pytest
 from unittest.mock import Mock, MagicMock
 from pqai_agent.agent_service import AgentService, MemoryQueueBackend
 from pqai_agent.dialogue_manager import DialogueState, TimeContext
-from pqai_agent.mq_message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import LocalUserManager
 import time
@@ -113,7 +113,7 @@ def test_normal_conversation_flow(test_env):
     service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备测试消息
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
     queues.receive_queue.produce(test_msg)
@@ -136,11 +136,11 @@ def test_aggregated_conversation_flow(test_env):
 
     # 准备测试消息
     ts_begin = int(time.time() * 1000)
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '你好', ts_begin)
     queues.receive_queue.produce(test_msg)
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
     queues.receive_queue.produce(test_msg)
@@ -162,7 +162,7 @@ def test_aggregated_conversation_flow(test_env):
     assert sent_msg is None
 
     # 模拟定时器产生空消息触发响应
-    service.process_single_message(Message.build(
+    service.process_single_message(MqMessage.build(
         MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', None, ts_begin + 2000
     ))
@@ -178,7 +178,7 @@ def test_human_intervention_trigger(test_env):
     service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备需要人工干预的消息
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         "user_id_0", "staff_id_0",
         "我需要帮助!", int(time.time() * 1000)