浏览代码

Update API server: update profile extractor

StrayWarrior 2 天之前
父节点
当前提交
8f0dfd3bae
共有 2 个文件被更改,包括 14 次插入28 次删除
  1. 1 1
      pqai_agent_server/api_server.py
  2. 13 27
      pqai_agent_server/utils/prompt_util.py

+ 1 - 1
pqai_agent_server/api_server.py

@@ -139,7 +139,7 @@ def get_base_prompt():
     prompt_map = {
         'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
         'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
-        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
+        'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
         'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
         'custom_debugging': '',
     }

+ 13 - 27
pqai_agent_server/utils/prompt_util.py

@@ -41,8 +41,7 @@ def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
             messages.append({"role": role, "content": f'{entry["content"]}'})
     return messages
 
-
-def run_openai_chat(messages, model_name, **kwargs):
+def create_llm_client(model_name):
     volcengine_models = [
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
         chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
@@ -72,6 +71,11 @@ def run_openai_chat(messages, model_name, **kwargs):
         )
     else:
         raise Exception("model not supported")
+    return llm_client
+
+
+def run_openai_chat(messages, model_name, **kwargs):
+    llm_client = create_llm_client(model_name)
     response = llm_client.chat.completions.create(
         messages=messages, model=model_name, **kwargs
     )
@@ -79,36 +83,18 @@ def run_openai_chat(messages, model_name, **kwargs):
     return response
 
 
-def run_extractor_prompt(req_data):
+def run_extractor_prompt(req_data) -> Dict[str, str]:
     prompt = req_data["prompt"]
     user_profile = req_data["user_profile"]
-    staff_profile = req_data["staff_profile"]
     dialogue_history = req_data["dialogue_history"]
     model_name = req_data["model_name"]
-    prompt_context = {
-        "formatted_staff_profile": format_agent_profile(staff_profile),
-        **user_profile,
-        "dialogue_history": UserProfileExtractor.compose_dialogue(dialogue_history),
-    }
-    prompt = prompt.format(**prompt_context)
-    messages = [
-        {"role": "system", "content": "你是一个专业的用户画像分析助手。"},
-        {"role": "user", "content": prompt},
-    ]
-    tools = [UserProfileExtractor.get_extraction_function()]
-    response = run_openai_chat(messages, model_name, tools=tools, temperature=0)
-    tool_calls = response.choices[0].message.tool_calls
-    if tool_calls:
-        function_call = tool_calls[0]
-        if function_call.function.name == "update_user_profile":
-            profile_info = json.loads(function_call.function.arguments)
-            return {k: v for k, v in profile_info.items() if v}
-        else:
-            logger.error("llm does not return update_user_profile")
-            return {}
-    else:
+    llm_client = create_llm_client(model_name)
+    extractor = UserProfileExtractor(model_name=model_name, llm_client=llm_client)
+    profile_to_update = extractor.extract_profile_info_v2(user_profile, dialogue_history, prompt)
+    logger.info(profile_to_update)
+    if not profile_to_update:
         return {}
-
+    return profile_to_update
 
 def run_chat_prompt(req_data):
     prompt = req_data["prompt"]