
Update api_server: add response_type_detector

StrayWarrior 11 hours ago
parent
commit
4146b7a1c6
1 changed file with 28 additions and 3 deletions

api_server.py  +28 -3

@@ -18,6 +18,7 @@ import logging_service
 import prompt_templates
 from dialogue_manager import DialogueManager
 from history_dialogue_service import HistoryDialogueService
+from response_type_detector import ResponseTypeDetector
 from user_manager import MySQLUserManager, MySQLUserRelationManager
 from user_profile_extractor import UserProfileExtractor
 
@@ -113,7 +114,8 @@ def list_scenes():
     scenes = [
         {'scene': 'greeting', 'display_name': '问候'},
         {'scene': 'chitchat', 'display_name': '闲聊'},
-        {'scene': 'profile_extractor', 'display_name': '画像提取'}
+        {'scene': 'profile_extractor', 'display_name': '画像提取'},
+        {'scene': 'response_type_detector', 'display_name': '回复模态判断'}
     ]
     return wrap_response(200, data=scenes)
 
@@ -123,12 +125,14 @@ def get_base_prompt():
     prompt_map = {
         'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
         'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
-        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT
+        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
+        'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT
     }
     model_map = {
         'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
+        'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
     }
     if scene not in prompt_map:
         return wrap_response(404, msg='scene not found')
@@ -213,6 +217,24 @@ def run_chat_prompt(req_data):
     messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
     return run_openai_chat(messages, model_name, temperature=1, top_p=0.7, max_tokens=1024)
 
+def run_response_type_prompt(req_data):
+    prompt = req_data['prompt']
+    dialogue_history = req_data['dialogue_history']
+    model_name = req_data['model_name']
+
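+    # All turns except the last form the conversation context; the final turn is the message whose reply type is detected.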
+    composed_dialogue = ResponseTypeDetector.compose_dialogue(dialogue_history[:-1])
+    next_message = DialogueManager.format_dialogue_content(dialogue_history[-1])
+    prompt = prompt.format(
+        dialogue_history=composed_dialogue,
+        message=next_message
+    )
+    messages = [
+        {'role': 'system', 'content': '你是一个专业的智能助手'},
+        {'role': 'user', 'content': prompt}
+    ]
+    return run_openai_chat(messages, model_name, temperature=0.2, max_tokens=128)
+
+
 @app.route('/api/runPrompt', methods=['POST'])
 def run_prompt():
     try:
@@ -222,6 +244,9 @@ def run_prompt():
         if scene == 'profile_extractor':
             response = run_extractor_prompt(req_data)
             return wrap_response(200, data=response)
+        elif scene == 'response_type_detector':
+            response = run_response_type_prompt(req_data)
+            return wrap_response(200, data=response)
         else:
             response = run_chat_prompt(req_data)
             return wrap_response(200, data=response.choices[0].message.content)
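
For reference, a minimal sketch of how a client could exercise the new scene through /api/runPrompt. Only the route, the scene name, and the top-level request keys (prompt, dialogue_history, model_name) come from the diff above; the host/port, the prompt wording, the model id, and the shape of each dialogue entry are assumptions for illustration.

import requests

# Sketch of a client call to the new response_type_detector scene.
# Host/port, prompt wording, model id, and dialogue-entry fields are assumed.
payload = {
    'scene': 'response_type_detector',
    # run_response_type_prompt calls prompt.format(dialogue_history=..., message=...),
    # so the submitted prompt must contain both placeholders.
    'prompt': '对话历史：\n{dialogue_history}\n\n最新消息：{message}\n请判断应采用的回复模态。',
    'model_name': 'doubao-pro-1.5',  # placeholder model id
    'dialogue_history': [
        {'role': 'user', 'content': '你最近忙什么呢？'},        # assumed entry shape
        {'role': 'assistant', 'content': '在准备一个新项目。'},  # assumed entry shape
    ],
}

resp = requests.post('http://localhost:5000/api/runPrompt', json=payload)  # assumed host/port
print(resp.json())  # wrap_response(200, data=...) carries the detector output in `data`

Note that the new branch in run_prompt wraps the value returned by run_response_type_prompt directly (like the profile_extractor branch), rather than extracting response.choices[0].message.content as the chat branch does, so the exact shape of `data` depends on what run_openai_chat returns.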