|
@@ -2,6 +2,7 @@
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# vim:fenc=utf-8
|
|
|
|
|
|
+import random
|
|
|
from openai import OpenAI
|
|
|
from datetime import datetime
|
|
|
import chat_service
|
|
@@ -10,6 +11,7 @@ import prompt_templates
|
|
|
from dialogue_manager import DialogueManager
|
|
|
from logging_service import logger
|
|
|
from message import MessageType
|
|
|
+import re
|
|
|
|
|
|
|
|
|
class ResponseTypeDetector:
|
|
@@ -33,7 +35,7 @@ class ResponseTypeDetector:
|
|
|
)
|
|
|
self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
|
|
|
|
|
|
- def detect_type(self, dialogue_history, next_message):
|
|
|
+ def detect_type(self, dialogue_history, next_message, enable_random=False):
|
|
|
if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
|
|
|
return MessageType.TEXT
|
|
|
composed_dialogue = self.compose_dialogue(dialogue_history)
|
|
@@ -52,5 +54,25 @@ class ResponseTypeDetector:
|
|
|
response = response.choices[0].message.content.strip()
|
|
|
if response == '语音':
|
|
|
return MessageType.VOICE
|
|
|
- else:
|
|
|
- return MessageType.TEXT
|
|
|
+ if enable_random:
|
|
|
+ if self.if_message_suitable_for_voice(next_message):
|
|
|
+ if random.random() < 0.2:
|
|
|
+ logger.info(f"enable voice response randomly for message: {next_message}")
|
|
|
+ return MessageType.VOICE
|
|
|
+ return MessageType.TEXT
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def is_chinese_only(text):
|
|
|
+ # 匹配中文字符和中文标点
|
|
|
+ pattern = re.compile(r'^[\u4e00-\u9fa5\u3000-\u303f\uff00-\uffef]+$')
|
|
|
+ return bool(pattern.fullmatch(text))
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def if_message_suitable_for_voice(message):
|
|
|
+ # 使用语音的文字不适合过长
|
|
|
+ if len(message) > 30:
|
|
|
+ return False
|
|
|
+ # 只有纯文字的消息适合使用语音
|
|
|
+ if not ResponseTypeDetector.is_chinese_only(message):
|
|
|
+ return False
|
|
|
+ return True
|