Przeglądaj źródła

Update dialogue_manager: some refactor

StrayWarrior 1 tydzień temu
rodzic
commit
5871cdaee5
1 zmienionych plików z 63 dodań i 45 usunięć
  1. 63 45
      dialogue_manager.py

+ 63 - 45
dialogue_manager.py

@@ -428,7 +428,8 @@ class DialogueManager:
 
         return context
 
-    def _select_prompt(self, state):
+    @staticmethod
+    def _select_prompt(state):
         state_to_prompt_map = {
             DialogueState.GREETING: GENERAL_GREETING_PROMPT,
             DialogueState.CHITCHAT: CHITCHAT_PROMPT_COZE,
@@ -436,7 +437,8 @@ class DialogueManager:
         }
         return state_to_prompt_map[state]
 
-    def _select_coze_bot(self, state):
+    @staticmethod
+    def _select_coze_bot(state):
         state_to_bot_map = {
             DialogueState.GREETING: '7486112546798780425',
             DialogueState.CHITCHAT: '7491300566573301770',
@@ -449,6 +451,58 @@ class DialogueManager:
         prompt = prompt_template.format(**prompt_context)
         return {'role': 'system', 'content': prompt}
 
+    @staticmethod
+    def compose_chat_messages_openai_compatible(dialogue_history, current_time):
+        messages = []
+        for entry in dialogue_history:
+            role = entry['role']
+            fmt_time = DialogueManager.format_timestamp(entry['timestamp'])
+            messages.append({
+                "role": role,
+                "content": '[{}] {}'.format(fmt_time, entry["content"])
+            })
+        # 添加一条前缀用于 约束时间场景
+        msg_prefix = '[{}]'.format(current_time)
+        messages.append({'role': 'assistant', 'content': msg_prefix})
+        return messages
+
+    @staticmethod
+    def compose_chat_messages_coze(dialogue_history, current_time, staff_id, user_id):
+        messages = []
+        # 如果system后的第1条消息不为user,需要补一条user消息
+        if len(dialogue_history) > 0 and dialogue_history[0]['role'] != 'user':
+            fmt_time = DialogueManager.format_timestamp(dialogue_history[0]['timestamp'])
+            messages.append(cozepy.Message.build_user_question_text(f'[{fmt_time}] '))
+        # coze最后一条消息必须为user,且可能吞掉连续的user消息,故强制增加一条空消息(可参与合并)
+        dialogue_history.append({
+            'role': 'user',
+            'content': ' ',
+            'timestamp': int(datetime.strptime(current_time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000),
+        })
+        # 将连续的同一角色的消息做聚合,避免coze吞消息
+        messages_to_aggr = []
+        last_message_role = None
+        for entry in dialogue_history:
+            if not entry['content']:
+                logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
+                    staff_id, user_id, entry['role']
+                ))
+                continue
+            role = entry['role']
+            if role != last_message_role:
+                if messages_to_aggr:
+                    aggregated_message = '\n'.join(messages_to_aggr)
+                    messages.append(DialogueManager.build_chat_message(
+                        last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
+                messages_to_aggr = []
+                last_message_role = role
+            messages_to_aggr.append(DialogueManager.format_dialogue_content(entry))
+        if messages_to_aggr:
+            aggregated_message = '\n'.join(messages_to_aggr)
+            messages.append(DialogueManager.build_chat_message(
+                last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
+        return messages
+
     def build_chat_configuration(
             self,
             user_message: Optional[str] = None,
@@ -472,54 +526,18 @@ class DialogueManager:
         if overwrite_context:
             prompt_context.update(overwrite_context)
 
+        # FIXME(zhoutian): time in string type
+        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        if overwrite_context and 'current_time' in overwrite_context:
+            current_time = overwrite_context.get('current_time')
+
         if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
             system_message = self._create_system_message(prompt_context)
             messages.append(system_message)
-            for entry in dialogue_history:
-                role = entry['role']
-                fmt_time = self.format_timestamp(entry['timestamp'])
-                messages.append({
-                    "role": role,
-                    "content": '[{}] {}'.format(fmt_time, entry["content"])
-                })
-            # 添加一条前缀用于 约束时间场景
-            msg_prefix = '[{}]'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
-            messages.append({'role': 'assistant', 'content': msg_prefix})
+            messages.extend(self.compose_chat_messages_openai_compatible(dialogue_history, current_time))
         elif chat_service_type == ChatServiceType.COZE_CHAT:
             dialogue_history = dialogue_history[-95:] # Coze最多支持100条,还需要附加系统消息
-            # 如果system后的第1条消息不为user,需要补一条user消息
-            if len(dialogue_history) > 0 and dialogue_history[0]['role'] != 'user':
-                fmt_time = self.format_timestamp(dialogue_history[0]['timestamp'])
-                messages.append(cozepy.Message.build_user_question_text(f'[{fmt_time}] '))
-            # coze最后一条消息必须为user,且可能吞掉连续的user消息,故强制增加一条空消息(可参与合并)
-            current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-            if overwrite_context and 'current_time' in overwrite_context:
-                current_time = overwrite_context.get('current_time')
-            dialogue_history.append({
-                'role': 'user',
-                'content': ' ',
-                'timestamp': int(datetime.strptime(current_time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000),
-            })
-            # 将连续的同一角色的消息做聚合,避免coze吞消息
-            messages_to_aggr = []
-            last_message_role = None
-            for entry in dialogue_history:
-                if not entry['content']:
-                    logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
-                        self.staff_id, self.user_id, entry['role']
-                    ))
-                    continue
-                role = entry['role']
-                if role != last_message_role:
-                    if messages_to_aggr:
-                        aggregated_message = '\n'.join(messages_to_aggr)
-                        messages.append(self.build_chat_message(last_message_role, aggregated_message, chat_service_type))
-                    messages_to_aggr = []
-                    last_message_role = role
-                messages_to_aggr.append(self.format_dialogue_content(entry))
-            if messages_to_aggr:
-                aggregated_message = '\n'.join(messages_to_aggr)
-                messages.append(self.build_chat_message(last_message_role, aggregated_message, chat_service_type))
+            messages = self.compose_chat_messages_coze(dialogue_history, current_time, self.staff_id, self.user_id)
             custom_variables = {}
             for k, v in prompt_context.items():
                 custom_variables[k] = str(v)