浏览代码

Add Message abstraction

StrayWarrior 7 月之前
父节点
当前提交
a648396eec
共有 6 个文件被更改,包括 129 次插入78 次删除
  1. 28 29
      agent_service.py
  2. 16 12
      dialogue_manager.py
  3. 50 10
      message.py
  4. 4 2
      message_queue_backend.py
  5. 23 0
      prompt_templates.py
  6. 8 25
      user_profile_extractor.py

+ 28 - 29
agent_service.py

@@ -20,7 +20,7 @@ from openai import OpenAI
 from message_queue_backend import MessageQueueBackend, MemoryQueueBackend
 from message_queue_backend import MessageQueueBackend, MemoryQueueBackend
 from user_profile_extractor import UserProfileExtractor
 from user_profile_extractor import UserProfileExtractor
 import threading
 import threading
-from message import MessageType
+from message import MessageType, Message, MessageChannel
 from logging_service import ColoredFormatter
 from logging_service import ColoredFormatter
 
 
 
 
@@ -92,19 +92,15 @@ class AgentService:
 
 
     def _schedule_aggregation_trigger(self, user_id: str, delay_sec: int):
     def _schedule_aggregation_trigger(self, user_id: str, delay_sec: int):
         logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
-        message = {
-            'user_id': user_id,
-            'type': MessageType.AGGREGATION_TRIGGER,
-            'text': None,
-            'timestamp': int(time.time() * 1000) + delay_sec * 1000
-        }
+        message_ts = int((time.time() + delay_sec) * 1000)
+        message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, None, user_id, None, message_ts)
+        message.id = -MessageType.AGGREGATION_TRIGGER.code
         self.scheduler.add_job(lambda: self.receive_queue.produce(message),
         self.scheduler.add_job(lambda: self.receive_queue.produce(message),
                                'date',
                                'date',
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
                                run_date=datetime.now() + timedelta(seconds=delay_sec))
 
 
-    def process_single_message(self, message: Dict):
-        user_id = message['user_id']
-        message_text = message.get('text', None)
+    def process_single_message(self, message: Message):
+        user_id = message.user_id
 
 
         # 获取用户信息和Agent实例
         # 获取用户信息和Agent实例
         user_profile = self.user_manager.get_user_profile(user_id)
         user_profile = self.user_manager.get_user_profile(user_id)
@@ -117,9 +113,9 @@ class AgentService:
 
 
         # 根据状态路由消息
         # 根据状态路由消息
         if agent.is_in_human_intervention():
         if agent.is_in_human_intervention():
-            self._route_to_human_intervention(user_id, message_text, dialogue_state)
+            self._route_to_human_intervention(user_id, message)
         elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
         elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
-            if message['type'] != MessageType.AGGREGATION_TRIGGER:
+            if message.type != MessageType.AGGREGATION_TRIGGER:
                 # 产生一个触发器,但是不能由触发器递归产生
                 # 产生一个触发器,但是不能由触发器递归产生
                 logging.debug("user: {}, waiting next message for aggregation".format(user_id))
                 logging.debug("user: {}, waiting next message for aggregation".format(user_id))
                 self._schedule_aggregation_trigger(user_id, agent.message_aggregation_sec)
                 self._schedule_aggregation_trigger(user_id, agent.message_aggregation_sec)
@@ -129,13 +125,16 @@ class AgentService:
             self._update_user_profile(user_id, user_profile, message_text)
             self._update_user_profile(user_id, user_profile, message_text)
             self._get_chat_response(user_id, agent, message_text)
             self._get_chat_response(user_id, agent, message_text)
 
 
-    def _route_to_human_intervention(self, user_id: str, user_message: str, state: DialogueState):
+    def _route_to_human_intervention(self, user_id: str, origin_message: Message):
         """路由到人工干预"""
         """路由到人工干预"""
-        self.human_queue.produce({
-            'user_id': user_id,
-            'state': state,
-            'timestamp': datetime.now().isoformat()
-        })
+        self.human_queue.produce(Message.build(
+            MessageType.TEXT,
+            origin_message.channel,
+            origin_message.staff_id,
+            origin_message.user_id,
+            "用户对话需人工介入,用户名:{}".format(user_id),
+            int(time.time() * 1000)
+        ))
 
 
     def _check_initiative_conversations(self):
     def _check_initiative_conversations(self):
         """定时检查主动发起对话"""
         """定时检查主动发起对话"""
@@ -154,7 +153,6 @@ class AgentService:
         """处理LLM响应"""
         """处理LLM响应"""
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         logging.debug(chat_config)
         logging.debug(chat_config)
-        # FIXME(zhoutian): 这里的抽象不够好,DialogueManager和AgentService有耦合
         chat_response = self._call_chat_api(chat_config)
         chat_response = self._call_chat_api(chat_config)
 
 
         if response := agent.generate_response(chat_response):
         if response := agent.generate_response(chat_response):
@@ -187,7 +185,7 @@ class AgentService:
 if __name__ == "__main__":
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.DEBUG)
     logging.getLogger().setLevel(logging.DEBUG)
     console_handler = logging.StreamHandler()
     console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.INFO)
+    console_handler.setLevel(logging.DEBUG)
     formatter = ColoredFormatter(
     formatter = ColoredFormatter(
         '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
         '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
     )
     )
@@ -220,15 +218,16 @@ if __name__ == "__main__":
     process_thread = threading.Thread(target=service.process_messages)
     process_thread = threading.Thread(target=service.process_messages)
     process_thread.start()
     process_thread.start()
 
 
+    message_id = 0
     while True:
     while True:
         print("Input next message: ")
         print("Input next message: ")
-        message = sys.stdin.readline().strip()
-        message_dict = {
-            "user_id": "user_id_1",
-            "type": MessageType.TEXT,
-            "text": message,
-            "timestamp": int(time.time() * 1000)
-        }
-        if message:
-            receive_queue.produce(message_dict)
+        text = sys.stdin.readline().strip()
+        if not text:
+            continue
+        message_id += 1
+        message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+            'staff_id_1','user_id_1', text, int(time.time() * 1000)
+        )
+        message.id = message_id
+        receive_queue.produce(message)
         time.sleep(0.1)
         time.sleep(0.1)

+ 16 - 12
dialogue_manager.py

@@ -11,7 +11,7 @@ import logging
 import cozepy
 import cozepy
 
 
 from chat_service import ChatServiceType
 from chat_service import ChatServiceType
-from message import MessageType
+from message import MessageType, Message
 # from vector_memory_manager import VectorMemoryManager
 # from vector_memory_manager import VectorMemoryManager
 from structured_memory_manager import StructuredMemoryManager
 from structured_memory_manager import StructuredMemoryManager
 from user_manager import UserManager
 from user_manager import UserManager
@@ -84,10 +84,10 @@ class DialogueManager:
         else:
         else:
             return TimeContext.NIGHT
             return TimeContext.NIGHT
 
 
-    def update_state(self, message: Dict) -> Tuple[DialogueState, str]:
+    def update_state(self, message: Message) -> Tuple[DialogueState, str]:
         """根据用户消息更新对话状态,并返回下一条需处理的用户消息"""
         """根据用户消息更新对话状态,并返回下一条需处理的用户消息"""
-        message_text = message.get('text', None)
-        message_ts = message['timestamp']
+        message_text = message.content
+        message_ts = message.timestamp
         # 如果当前已经是人工介入状态,保持该状态
         # 如果当前已经是人工介入状态,保持该状态
         if self.current_state == DialogueState.HUMAN_INTERVENTION:
         if self.current_state == DialogueState.HUMAN_INTERVENTION:
             # 记录对话历史,但不改变状态
             # 记录对话历史,但不改变状态
@@ -102,7 +102,7 @@ class DialogueManager:
         # 检查是否处于消息聚合状态
         # 检查是否处于消息聚合状态
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
-            if message['type'] == MessageType.AGGREGATION_TRIGGER \
+            if message.type == MessageType.AGGREGATION_TRIGGER \
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
                 logging.debug("user_id: {}, last interaction time: {}".format(
                 logging.debug("user_id: {}, last interaction time: {}".format(
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
@@ -113,7 +113,7 @@ class DialogueManager:
                     self.unprocessed_messages.append(message_text)
                     self.unprocessed_messages.append(message_text)
                     self.last_interaction_time = message_ts
                     self.last_interaction_time = message_ts
                 return self.current_state, message_text
                 return self.current_state, message_text
-        elif message['type'] != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
+        elif message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
             # 收到有内容的用户消息,切换到消息聚合状态
             # 收到有内容的用户消息,切换到消息聚合状态
             self.previous_state = self.current_state
             self.previous_state = self.current_state
             self.current_state = DialogueState.MESSAGE_AGGREGATING
             self.current_state = DialogueState.MESSAGE_AGGREGATING
@@ -164,18 +164,18 @@ class DialogueManager:
             self.dialogue_history.append({
             self.dialogue_history.append({
                 "role": "user",
                 "role": "user",
                 "content": message_text,
                 "content": message_text,
-                "timestamp": int(time.time() * 1000),
+                "timestamp": message_ts,
                 "state": self.current_state.name
                 "state": self.current_state.name
             })
             })
 
 
         return self.current_state, message_text
         return self.current_state, message_text
 
 
-    def _determine_state_from_message(self, message: str) -> DialogueState:
+    def _determine_state_from_message(self, message_text: str) -> DialogueState:
         """根据消息内容确定对话状态"""
         """根据消息内容确定对话状态"""
-        if not message:
+        if not message_text:
             return self.current_state
             return self.current_state
         # 简单的规则-关键词匹配
         # 简单的规则-关键词匹配
-        message_lower = message.lower()
+        message_lower = message_text.lower()
 
 
         # 判断是否是复杂请求
         # 判断是否是复杂请求
         complex_request_keywords = ["帮我", "怎么办", "我需要", "麻烦你", "请帮助", "急", "紧急"]
         complex_request_keywords = ["帮我", "怎么办", "我需要", "麻烦你", "请帮助", "急", "紧急"]
@@ -337,7 +337,8 @@ class DialogueManager:
             # "dialogue_history": self.dialogue_history[-10:],
             # "dialogue_history": self.dialogue_history[-10:],
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": False,
             "if_first_interaction": False,
-            "if_active_greeting": False if user_message else True
+            "if_active_greeting": False if user_message else True,
+            **self.user_profile
         }
         }
 
 
         # 获取长期记忆
         # 获取长期记忆
@@ -365,7 +366,7 @@ class DialogueManager:
 
 
     def _create_system_message(self, prompt_context):
     def _create_system_message(self, prompt_context):
         prompt_template = self._select_prompt(self.current_state)
         prompt_template = self._select_prompt(self.current_state)
-        prompt = prompt_template.format(**prompt_context['user_profile'], **prompt_context)
+        prompt = prompt_template.format(**prompt_context)
         return {'role': 'system', 'content': prompt}
         return {'role': 'system', 'content': prompt}
 
 
     def build_chat_configuration(
     def build_chat_configuration(
@@ -405,8 +406,11 @@ class DialogueManager:
             custom_variables = {}
             custom_variables = {}
             for k, v in prompt_context.items():
             for k, v in prompt_context.items():
                 custom_variables[k] = str(v)
                 custom_variables[k] = str(v)
+            custom_variables.pop('user_profile', None)
             config['custom_variables'] = custom_variables
             config['custom_variables'] = custom_variables
             config['bot_id'] = self._select_coze_bot(self.current_state)
             config['bot_id'] = self._select_coze_bot(self.current_state)
+            if not user_message:
+                messages.append(cozepy.Message.build_user_question_text('请开始对话'))
         config['messages'] = messages
         config['messages'] = messages
 
 
         return config
         return config

+ 50 - 10
message.py

@@ -4,18 +4,58 @@
 
 
 
 
 from enum import Enum, auto
 from enum import Enum, auto
+from typing import Optional
+
+from pydantic import BaseModel
 
 
 class MessageType(Enum):
 class MessageType(Enum):
-    TEXT = auto()
-    AUDIO = auto()
-    IMAGE = auto()
-    VIDEO = auto()
-    MINIGRAM = auto()
-    LINK = auto()
+    TEXT = (1, "文本")
+    AUDIO = (2, "音频")
+    IMAGE = (3, "图片")
+    VIDEO = (4, "视频")
+    MINI_PROGRAM = (5, "小程序")
+    LINK = (6, "链接")
+
+    ACTIVE_TRIGGER = (101, "主动触发器")
+    AGGREGATION_TRIGGER = (102, "消息聚合触发器")
 
 
-    ACTIVE_TRIGGER = auto()
-    AGGREGATION_TRIGGER = auto()
+    def __init__(self, code, description):
+        self.code = code
+        self.description = description
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}.{self.name}"
 
 
 class MessageChannel(Enum):
 class MessageChannel(Enum):
-    CORP_WECHAT = auto()
-    MINI_PROGRAM = auto()
+    CORP_WECHAT = (1, "企业微信")
+    MINI_PROGRAM = (2, "小程序")
+
+    SYSTEM = (101, "系统内部")
+
+    def __init__(self, code, description):
+        self.code = code
+        self.description = description
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}.{self.name}"
+
+class Message(BaseModel):
+     id: int
+     type: MessageType
+     channel: MessageChannel
+     staff_id: Optional[str] = None
+     user_id: str
+     content: Optional[str] = None
+     timestamp: int
+
+     @staticmethod
+     def build(type, channel, staff_id, user_id, content, timestamp):
+         return Message(
+             id=0,
+             type=type,
+             channel=channel,
+             staff_id=staff_id,
+             user_id=user_id,
+             content=content,
+             timestamp=timestamp
+         )

+ 4 - 2
message_queue_backend.py

@@ -5,6 +5,8 @@
 import abc
 import abc
 from typing import Dict, Any
 from typing import Dict, Any
 
 
+from message import Message
+
 
 
 class MessageQueueBackend(abc.ABC):
 class MessageQueueBackend(abc.ABC):
     @abc.abstractmethod
     @abc.abstractmethod
@@ -12,7 +14,7 @@ class MessageQueueBackend(abc.ABC):
         pass
         pass
 
 
     @abc.abstractmethod
     @abc.abstractmethod
-    def produce(self, message: Dict) -> None:
+    def produce(self, message: Message) -> None:
         pass
         pass
 
 
 class MemoryQueueBackend(MessageQueueBackend):
 class MemoryQueueBackend(MessageQueueBackend):
@@ -23,5 +25,5 @@ class MemoryQueueBackend(MessageQueueBackend):
     def consume(self):
     def consume(self):
         return self._queue.pop(0) if self._queue else None
         return self._queue.pop(0) if self._queue else None
 
 
-    def produce(self, message: Dict):
+    def produce(self, message: Message):
         self._queue.append(message)
         self._queue.append(message)

+ 23 - 0
prompt_templates.py

@@ -171,3 +171,26 @@ CHITCHAT_PROMPT_V2 = """
 # 输出
 # 输出
   对话回复
   对话回复
 """
 """
+
+
+USER_PROFILE_EXTRACT_PROMPT = """
+请在已有的用户画像的基础上,仔细分析以下对话内容,完善用户的画像信息。
+已知信息(可能为空):
+- 姓名:{name}
+- 希望的称呼:{preferred_nickname}
+- 年龄:{age}
+- 地区:{region}
+- 健康状况:{health_conditions}
+- 兴趣爱好:{interests}
+
+对话历史:
+{dialogue_history}
+
+提取要求:
+1. 尽可能准确地识别用户的年龄、兴趣爱好、健康状况
+2. 关注用户生活、家庭等隐性信息
+3. 信息提取一定要有很高的准确性!如果无法确定具体信息,一定不要猜测!
+4. 兴趣爱好必须是用户明确提到喜欢参与的活动,且只保留最关键的5项。一定不要猜测!一定不要轻易把用户的常规话题和需求当作兴趣爱好!
+
+请使用update_user_profile函数返回需要更新的信息,注意不要返回无需更新的信息。
+"""

+ 8 - 25
user_profile_extractor.py

@@ -4,7 +4,7 @@
 
 
 import json
 import json
 from typing import Dict, Any, Optional
 from typing import Dict, Any, Optional
-from datetime import datetime
+from prompt_templates import USER_PROFILE_EXTRACT_PROMPT
 from openai import OpenAI
 from openai import OpenAI
 import logging
 import logging
 
 
@@ -45,7 +45,7 @@ class UserProfileExtractor:
                         },
                         },
                         "region": {
                         "region": {
                             "type": "string",
                             "type": "string",
-                            "description": "用户所在地"
+                            "description": "用户常驻的地区,不是用户临时所在地"
                         },
                         },
                         "interests": {
                         "interests": {
                             "type": "array",
                             "type": "array",
@@ -69,27 +69,7 @@ class UserProfileExtractor:
         """
         """
         context = user_profile.copy()
         context = user_profile.copy()
         context['dialogue_history'] = dialogue_history
         context['dialogue_history'] = dialogue_history
-        return """
-请在已有的用户画像的基础上,仔细分析以下对话内容,完善用户的画像信息。
-已知信息(可能为空):
-- 姓名:{name}
-- 希望的称呼:{preferred_nickname}
-- 年龄:{age}
-- 地区:{region}
-- 健康状况:{health_conditions}
-- 兴趣爱好:{interests}
-
-对话历史:
-{dialogue_history}
-
-提取要求:
-1. 尽可能准确地识别用户的年龄、兴趣爱好、健康状况
-2. 关注用户生活、家庭等隐性信息
-3. 信息提取需要有较高的置信度,兴趣爱好只保留用户明确喜欢且最关键的5项
-4. 如果无法确定具体信息,请不要猜测
-
-请使用update_user_profile函数返回需要更新的信息,注意不要返回无需更新的信息。
-""".format(**context)
+        return USER_PROFILE_EXTRACT_PROMPT.format(**context)
 
 
     def extract_profile_info(self, user_profile, dialogue_history: str) -> Optional[Dict]:
     def extract_profile_info(self, user_profile, dialogue_history: str) -> Optional[Dict]:
         """
         """
@@ -139,13 +119,16 @@ if __name__ == '__main__':
     extractor = UserProfileExtractor()
     extractor = UserProfileExtractor()
     current_profile = {
     current_profile = {
         'name': '',
         'name': '',
-        'preferred_nickname': '',
+        'preferred_nickname': '李叔',
         'age': 0,
         'age': 0,
-        'region': '',
+        'region': '北京',
         'health_conditions': [],
         'health_conditions': [],
         'medications': [],
         'medications': [],
         'interests': []
         'interests': []
     }
     }
+    message = "我回天津老家了"
+    resp = extractor.extract_profile_info(current_profile, message)
+    print(resp)
     message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
     message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
               "你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
               "你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
               "哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \
               "哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \