|
@@ -18,6 +18,7 @@ import logging_service
|
|
|
import prompt_templates
|
|
|
from dialogue_manager import DialogueManager
|
|
|
from history_dialogue_service import HistoryDialogueService
|
|
|
+from response_type_detector import ResponseTypeDetector
|
|
|
from user_manager import MySQLUserManager, MySQLUserRelationManager
|
|
|
from user_profile_extractor import UserProfileExtractor
|
|
|
|
|
@@ -113,7 +114,8 @@ def list_scenes():
|
|
|
scenes = [
|
|
|
{'scene': 'greeting', 'display_name': '问候'},
|
|
|
{'scene': 'chitchat', 'display_name': '闲聊'},
|
|
|
- {'scene': 'profile_extractor', 'display_name': '画像提取'}
|
|
|
+ {'scene': 'profile_extractor', 'display_name': '画像提取'},
|
|
|
+ {'scene': 'response_type_detector', 'display_name': '回复模态判断'}
|
|
|
]
|
|
|
return wrap_response(200, data=scenes)
|
|
|
|
|
@@ -123,12 +125,14 @@ def get_base_prompt():
|
|
|
prompt_map = {
|
|
|
'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
|
|
|
'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
|
|
|
- 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT
|
|
|
+ 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
|
|
|
+ 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT
|
|
|
}
|
|
|
model_map = {
|
|
|
'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
- 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
|
|
|
+ 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
+ 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
|
|
|
}
|
|
|
if scene not in prompt_map:
|
|
|
return wrap_response(404, msg='scene not found')
|
|
@@ -213,6 +217,24 @@ def run_chat_prompt(req_data):
|
|
|
messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
|
|
|
return run_openai_chat(messages, model_name, temperature=1, top_p=0.7, max_tokens=1024)
|
|
|
|
|
|
+def run_response_type_prompt(req_data):
|
|
|
+ prompt = req_data['prompt']
|
|
|
+ dialogue_history = req_data['dialogue_history']
|
|
|
+ model_name = req_data['model_name']
|
|
|
+
|
|
|
+ composed_dialogue = ResponseTypeDetector.compose_dialogue(dialogue_history[:-1])
|
|
|
+ next_message = DialogueManager.format_dialogue_content(dialogue_history[-1])
|
|
|
+ prompt = prompt.format(
|
|
|
+ dialogue_history=composed_dialogue,
|
|
|
+ message=next_message
|
|
|
+ )
|
|
|
+ messages = [
|
|
|
+ {'role': 'system', 'content': '你是一个专业的智能助手'},
|
|
|
+ {'role': 'user', 'content': prompt}
|
|
|
+ ]
|
|
|
+ return run_openai_chat(messages, model_name,temperature=0.2, max_tokens=128)
|
|
|
+
|
|
|
+
|
|
|
@app.route('/api/runPrompt', methods=['POST'])
|
|
|
def run_prompt():
|
|
|
try:
|
|
@@ -222,6 +244,9 @@ def run_prompt():
|
|
|
if scene == 'profile_extractor':
|
|
|
response = run_extractor_prompt(req_data)
|
|
|
return wrap_response(200, data=response)
|
|
|
+ elif scene == 'response_type_detector':
|
|
|
+ response = run_response_type_prompt(req_data)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
else:
|
|
|
response = run_chat_prompt(req_data)
|
|
|
return wrap_response(200, data=response.choices[0].message.content)
|