Procházet zdrojové kódy

Update dialogue_manager: support images in coze

StrayWarrior před 1 dnem
rodič
revize
9e4b9bad10
1 změnil soubory, kde provedl 37 přidání a 3 odebrání
  1. 37 3
      dialogue_manager.py

+ 37 - 3
dialogue_manager.py

@@ -514,6 +514,7 @@ class DialogueManager:
         })
         # 将连续的同一角色的消息做聚合,避免coze吞消息
         messages_to_aggr = []
+        objects_to_aggr = []
         last_message_role = None
         for entry in dialogue_history:
             if not entry['content']:
@@ -523,14 +524,47 @@ class DialogueManager:
                 continue
             role = entry['role']
             if role != last_message_role:
-                if messages_to_aggr:
+                if objects_to_aggr:
+                    if last_message_role != 'user':
+                        pass
+                    else:
+                        text_message = '\n'.join(messages_to_aggr)
+                        object_string_list = []
+                        for entry in objects_to_aggr:
+                            # FIXME: 其它消息类型的支持
+                            object_string_list.append(cozepy.MessageObjectString.build_image(file_url=entry['content']))
+                        object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
+                        messages.append(cozepy.Message.build_user_question_objects(object_string_list))
+                elif messages_to_aggr:
                     aggregated_message = '\n'.join(messages_to_aggr)
                     messages.append(DialogueManager.build_chat_message(
                         last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
+                objects_to_aggr = []
                 messages_to_aggr = []
                 last_message_role = role
-            messages_to_aggr.append(DialogueManager.format_dialogue_content(entry))
-        if messages_to_aggr:
+            if entry.get('type', MessageType.TEXT) in (MessageType.IMAGE_GW, MessageType.IMAGE_QW):
+                # 多模态消息必须用特殊的聚合方式,一个object_string数组中只能有一个文字消息,但可以有多个图片
+                if role == 'user':
+                    objects_to_aggr.append(entry)
+                else:
+                    logger.warning("staff[{}], user[{}]: unsupported message type [{}] in assistant role".format(
+                        staff_id, user_id, entry['type']
+                    ))
+            else:
+                messages_to_aggr.append(DialogueManager.format_dialogue_content(entry))
+        # 如果有未聚合的object消息,需要特殊处理
+        if objects_to_aggr:
+            if last_message_role != 'user':
+                pass
+            else:
+                text_message = '\n'.join(messages_to_aggr)
+                object_string_list = []
+                for entry in objects_to_aggr:
+                    # FIXME: 其它消息类型的支持
+                    object_string_list.append(cozepy.MessageObjectString.build_image(file_url=entry['content']))
+                object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
+                messages.append(cozepy.Message.build_user_question_objects(object_string_list))
+        elif messages_to_aggr:
             aggregated_message = '\n'.join(messages_to_aggr)
             messages.append(DialogueManager.build_chat_message(
                 last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))