
Update api_server: support run_prompt for multiple scenes

StrayWarrior 1 week ago
Parent
Commit
4bc7bd6300
1 changed file with 74 additions and 45 deletions

+ 74 - 45
api_server.py

@@ -12,11 +12,13 @@ from openai import OpenAI
 
 import chat_service
 import configs
+import json
 import logging_service
 import prompt_templates
 from dialogue_manager import DialogueManager
 from history_dialogue_service import HistoryDialogueService
 from user_manager import MySQLUserManager, MySQLUserRelationManager
+from user_profile_extractor import UserProfileExtractor
 
 app = Flask('agent_api_server')
 
@@ -134,60 +136,87 @@ def get_base_prompt():
     }
     return wrap_response(200, data=data)
 
-def get_llm_response(model_name, messages):
-    pass
+def run_openai_chat(messages, model_name, **kwargs):
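+    # Route the request to the OpenAI-compatible backend (Volcengine or DeepSeek) that serves model_name.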
+    volcengine_models = [
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
+        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
+        chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
+    ]
+    deepseek_models = [
+        chat_service.DEEPSEEK_CHAT_MODEL,
+    ]
+    if model_name in volcengine_models:
+        llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
+        response = llm_client.chat.completions.create(
+            messages=messages, model=model_name, **kwargs)
+        return response
+    elif model_name in deepseek_models:
+        llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
+        response = llm_client.chat.completions.create(
+            messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
+        return response
+    else:
+        raise Exception('model not supported')
 
-def run_chat_prompt():
-    pass
+def run_extractor_prompt(req_data):
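+    # Fill the extraction prompt with profile and dialogue context, then return the non-empty fields from the model's update_user_profile tool call.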
+    prompt = req_data['prompt']
+    user_profile = req_data['user_profile']
+    staff_profile = req_data['staff_profile']
+    dialogue_history = req_data['dialogue_history']
+    model_name = req_data['model_name']
+    prompt_context = {**staff_profile,
+                      **user_profile,
+                      'dialogue_history': UserProfileExtractor.compose_dialogue(dialogue_history)}
+    prompt = prompt.format(**prompt_context)
+    messages = [
+        {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
+        {"role": "user", "content": prompt}
+    ]
+    tools = [UserProfileExtractor.get_extraction_function()]
+    response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
+    tool_calls = response.choices[0].message.tool_calls
+    if tool_calls:
+        function_call = tool_calls[0]
+        if function_call.function.name == 'update_user_profile':
+            profile_info = json.loads(function_call.function.arguments)
+            return {k: v for k, v in profile_info.items() if v}
+    else:
+        raise Exception("LLM did not return tool_calls")
 
-def run_extractor_prompt():
-    pass
+def run_chat_prompt(req_data):
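+    # Build the chat system prompt (profiles plus time-of-day context), append the dialogue history and return the raw completion.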
+    prompt = req_data['prompt']
+    staff_profile = req_data['staff_profile']
+    user_profile = req_data['user_profile']
+    dialogue_history = req_data['dialogue_history']
+    model_name = req_data['model_name']
+    current_timestamp = req_data['current_timestamp'] / 1000
+    prompt_context = {**staff_profile, **user_profile}
+    current_hour = datetime.fromtimestamp(current_timestamp).hour
+    prompt_context['last_interaction_interval'] = 0
+    prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
+    prompt_context['current_hour'] = current_hour
+    prompt_context['if_first_interaction'] = False if dialogue_history else True
+
+    current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
+    system_prompt = {
+        'role': 'system',
+        'content': prompt.format(**prompt_context)
+    }
+    messages = [system_prompt]
+    messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
+    return run_openai_chat(messages, model_name, temperature=1, top_p=0.7, max_tokens=1024)
 
 @app.route('/api/runPrompt', methods=['POST'])
 def run_prompt():
     try:
         req_data = request.json
         scene = req_data['scene']
-        prompt = req_data['prompt']
-        staff_profile = req_data['staff_profile']
-        user_profile = req_data['user_profile']
-        dialogue_history = req_data['dialogue_history']
-        model_name = req_data['model_name']
-        current_timestamp = req_data['current_timestamp'] / 1000
-        prompt_context = {**staff_profile, **user_profile}
-        current_hour = datetime.fromtimestamp(current_timestamp).hour
-        prompt_context['last_interaction_interval'] = 0
-        prompt_context['current_time_period'] = DialogueManager.get_time_context(current_hour)
-        prompt_context['current_hour'] = current_hour
-        prompt_context['if_first_interaction'] = False if dialogue_history else True
-        volcengine_models = [
-            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
-            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
-            chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
-        ]
-        deepseek_models = [
-            chat_service.DEEPSEEK_CHAT_MODEL,
-        ]
-        current_time_str = datetime.fromtimestamp(current_timestamp).strftime('%Y-%m-%d %H:%M:%S')
-        system_prompt = {
-            'role': 'system',
-            'content': prompt.format(**prompt_context)
-        }
-        messages = []
-        messages.append(system_prompt)
-        messages.extend(DialogueManager.compose_chat_messages_openai_compatible(dialogue_history, current_time_str))
-        if model_name in volcengine_models:
-            llm_client = OpenAI(api_key=chat_service.VOLCENGINE_API_TOKEN, base_url=chat_service.VOLCENGINE_BASE_URL)
-            response = llm_client.chat.completions.create(
-                messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
-            return wrap_response(200, data=response.choices[0].message.content)
-        elif model_name in deepseek_models:
-            llm_client = OpenAI(api_key=chat_service.DEEPSEEK_API_TOKEN, base_url=chat_service.DEEPSEEK_BASE_URL)
-            response = llm_client.chat.completions.create(
-                messages=messages, model=model_name, temperature=1, top_p=0.7, max_tokens=1024)
-            return wrap_response(200, data=response.choices[0].message.content)
+        if scene == 'profile_extractor':
+            response = run_extractor_prompt(req_data)
+            return wrap_response(200, data=response)
         else:
-            return wrap_response(400, msg='model not supported')
+            response = run_chat_prompt(req_data)
+            return wrap_response(200, data=response.choices[0].message.content)
     except Exception as e:
         return wrap_response(500, msg='Error: {}'.format(e))
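
A minimal usage sketch for the updated /api/runPrompt endpoint. The field names, the two scene branches, and the response shapes come from the handler above; the host/port, the model-name string, the prompt templates, the profile fields, and the shape of each dialogue_history entry are illustrative assumptions not shown in this diff.

import time
import requests

BASE_URL = 'http://localhost:8083'  # assumed host/port, not defined in this diff

payload = {
    'prompt': 'Dialogue so far:\n{dialogue_history}',  # placeholders must match the server-side prompt_context keys
    'staff_profile': {'agent_name': 'Agent A'},        # illustrative profile fields
    'user_profile': {'nickname': 'User B'},
    'dialogue_history': [                              # assumed entry shape; composed server-side
        {'role': 'user', 'content': 'Hello there', 'timestamp': int(time.time() * 1000)}
    ],
    'model_name': 'deepseek-chat',                     # assumed to match chat_service.DEEPSEEK_CHAT_MODEL
}

# scene == 'profile_extractor' -> run_extractor_prompt, returns the extracted profile fields
resp = requests.post(f'{BASE_URL}/api/runPrompt',
                     json={'scene': 'profile_extractor', **payload})
print(resp.json())

# any other scene -> run_chat_prompt, returns the generated reply text
payload['prompt'] = 'You are {agent_name}; it is now {current_time_period}.'
resp = requests.post(f'{BASE_URL}/api/runPrompt',
                     json={'scene': 'chat', 'current_timestamp': int(time.time() * 1000), **payload})
print(resp.json())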