瀏覽代碼

Update Message definition

StrayWarrior 2 周之前
父節點
當前提交
854b6ab7a9
共有 3 個文件被更改,包括 117 次插入44 次删除
  1. 2 2
      agent_service.py
  2. 1 1
      dialogue_manager.py
  3. 114 41
      message.py

+ 2 - 2
agent_service.py

@@ -96,7 +96,7 @@ class AgentService:
         logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
         message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
-        message.id = -MessageType.AGGREGATION_TRIGGER.code
+        message.msgId = -MessageType.AGGREGATION_TRIGGER.code
         self.scheduler.add_job(lambda: self.receive_queue.produce(message),
                                'date',
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
@@ -234,6 +234,6 @@ if __name__ == "__main__":
         message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
             'user_id_1','staff_id_0', text, int(time.time() * 1000)
         )
-        message.id = message_id
+        message.msgId = message_id
         receive_queue.produce(message)
         time.sleep(0.1)

+ 1 - 1
dialogue_manager.py

@@ -88,7 +88,7 @@ class DialogueManager:
     def update_state(self, message: Message) -> Tuple[DialogueState, Optional[str]]:
         """根据用户消息更新对话状态,并返回下一条需处理的用户消息"""
         message_text = message.content
-        message_ts = message.timestamp
+        message_ts = message.sendTime
         # 如果当前已经是人工介入状态,保持该状态
         if self.current_state == DialogueState.HUMAN_INTERVENTION:
             # 记录对话历史,但不改变状态

+ 114 - 41
message.py

@@ -6,68 +6,141 @@
 from enum import Enum, auto
 from typing import Optional
 
+import rocketmq
 from pydantic import BaseModel
 
-class MessageType(Enum):
-    DEFAULT = (-1, "未分类的消息")
-    TEXT = (1, "文本")
-    VOICE = (2, "语音")
-    GIF = (3, "GIF")
-    IMAGE_GW = (4, "个微图片")
-    IMAGE_QW = (5, "企微图片")
-    MINI_PROGRAM = (6, "小程序")
-    LINK = (7, "链接")
-    SHI_PIN_HAO = (8, "视频号")
-    NAME_CARD = (9, "名片")
-    POSITION = (10, "位置")
-    RED_PACKET = (11, "红包")
-    FILE_GW = (12, "个微文件")
-    FILE_QW = (13, "企微文件")
-    VIDEO_GW = (14, "个微视频")
-    VIDEO_QW = (15, "企微视频")
-    AGGREGATION_MSG = (16, "聚合消息")
+# class MessageType(Enum):
+#     DEFAULT = (-1, "未分类的消息")
+#     TEXT = (1, "文本")
+#     VOICE = (2, "语音")
+#     GIF = (3, "GIF")
+#     IMAGE_GW = (4, "个微图片")
+#     IMAGE_QW = (5, "企微图片")
+#     MINI_PROGRAM = (6, "小程序")
+#     LINK = (7, "链接")
+#     SHI_PIN_HAO = (8, "视频号")
+#     NAME_CARD = (9, "名片")
+#     POSITION = (10, "位置")
+#     RED_PACKET = (11, "红包")
+#     FILE_GW = (12, "个微文件")
+#     FILE_QW = (13, "企微文件")
+#     VIDEO_GW = (14, "个微视频")
+#     VIDEO_QW = (15, "企微视频")
+#     AGGREGATION_MSG = (16, "聚合消息")
+#
+#     ACTIVE_TRIGGER = (101, "主动触发器")
+#     AGGREGATION_TRIGGER = (102, "消息聚合触发器")
+#
+#     def __init__(self, code, description):
+#         self.code = code
+#         self.description = description
+#
+#     def __repr__(self):
+#         return f"{self.__class__.__name__}.{self.name}"
 
-    ACTIVE_TRIGGER = (101, "主动触发器")
-    AGGREGATION_TRIGGER = (102, "消息聚合触发器")
+class MessageType(int, Enum):
+    DEFAULT = -1
+    TEXT = 1
+    VOICE = 2
+    GIF = 3
+    IMAGE_GW = 4
+    IMAGE_QW = 5
+    MINI_PROGRAM = 6
+    LINK = 7
+    SHI_PIN_HAO = 8
+    NAME_CARD = 9
+    POSITION = 10
+    RED_PACKET = 11
+    FILE_GW = 12
+    FILE_QW = 13
+    VIDEO_GW = 14
+    VIDEO_QW = 15
+    AGGREGATION_MSG = 16
 
-    def __init__(self, code, description):
-        self.code = code
-        self.description = description
+    ACTIVE_TRIGGER = 101
+    AGGREGATION_TRIGGER = 102
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}.{self.name}"
+    def __init__(self, code):
+        self.description = {
+            -1: "未分类的消息",
+            1: "文本",
+            2: "语音",
+            3: "GIF",
+            4: "个微图片",
+            5: "企微图片",
+            6: "小程序",
+            7: "链接",
+            8: "视频号",
+            9: "名片",
+            10: "位置",
+            11: "红包",
+            12: "个微文件",
+            13: "企微文件",
+            14: "个微视频",
+            15: "企微视频",
+            16: "聚合消息",
+            101: "主动触发器",
+            102: "消息聚合触发器"
+        }[code]
 
-class MessageChannel(Enum):
-    CORP_WECHAT = (1, "企业微信")
-    MINI_PROGRAM = (2, "小程序")
+# class MessageChannel(Enum):
+#     CORP_WECHAT = (1, "企业微信")
+#     MINI_PROGRAM = (2, "小程序")
+#
+#     SYSTEM = (101, "系统内部")
+#
+#     def __init__(self, code, description):
+#         self.code = code
+#         self.description = description
+#
+#     def __repr__(self):
+#         return f"{self.__class__.__name__}.{self.name}"
 
-    SYSTEM = (101, "系统内部")
+class MessageChannel(int, Enum):
+    CORP_WECHAT = 1
+    MINI_PROGRAM = 2
+    SYSTEM = 101
 
-    def __init__(self, code, description):
-        self.code = code
-        self.description = description
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}.{self.name}"
+    def __init__(self, code):
+        self.description = {
+            1: "企业微信",
+            2: "小程序",
+            101: "系统内部"
+        }[code]
 
 class Message(BaseModel):
-     id: int
+     msgId: Optional[int] = None
      type: MessageType
      channel: MessageChannel
      sender: Optional[str] = None
+     senderUnionId: Optional[str] = None
      receiver: str
      content: Optional[str] = None
-     timestamp: int
-     ref_msg_id: Optional[int] = None
+     # 由于需要和其它语言如Java进行序列化和反序列化交互,因此使用camelCase命名法
+     sendTime: int
+     refMsgId: Optional[int] = None
+
+     # 原始的RocketMQ消息体,用于ack
+     _rmq_message: Optional[rocketmq.Message] = None
 
      @staticmethod
      def build(type, channel, sender, receiver, content, timestamp):
          return Message(
-             id=0,
+             msgId=0,
              type=type,
              channel=channel,
              sender=sender,
              receiver=receiver,
              content=content,
-             timestamp=timestamp
-         )
+             sendTime=timestamp
+         )
+
+     def to_json(self):
+         return self.model_dump_json(include={
+             "msgId", "type", "channel", "sender", "senderUnionId",
+             "receiver", "content", "sendTime", "refMsgId"
+         })
+
+     @staticmethod
+     def from_json(json_str):
+         return Message.model_validate_json(json_str)