|
@@ -3,7 +3,7 @@
|
|
|
# vim:fenc=utf-8
|
|
|
|
|
|
import json
|
|
|
-from typing import Dict, Any, Optional
|
|
|
+from typing import Dict, Any, Optional, List
|
|
|
|
|
|
import chat_service
|
|
|
import configs
|
|
@@ -18,7 +18,7 @@ class UserProfileExtractor:
|
|
|
api_key=chat_service.VOLCENGINE_API_TOKEN,
|
|
|
base_url=chat_service.VOLCENGINE_BASE_URL
|
|
|
)
|
|
|
- self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
|
|
|
+ self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
|
|
|
def get_extraction_function(self) -> Dict:
|
|
|
"""
|
|
@@ -38,7 +38,7 @@ class UserProfileExtractor:
|
|
|
},
|
|
|
"preferred_nickname": {
|
|
|
"type": "string",
|
|
|
- "description": "用户希望对其的称呼,如果能够准确识别"
|
|
|
+ "description": "用户希望对其的称呼,如果用户明确提到"
|
|
|
},
|
|
|
"gender": {
|
|
|
"type": "string",
|
|
@@ -72,15 +72,27 @@ class UserProfileExtractor:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: str) -> str:
|
|
|
+ def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: List[Dict]) -> str:
|
|
|
"""
|
|
|
生成用于信息提取的系统提示词
|
|
|
"""
|
|
|
context = user_profile.copy()
|
|
|
- context['dialogue_history'] = dialogue_history
|
|
|
+ context['dialogue_history'] = self.compose_dialogue(dialogue_history)
|
|
|
return USER_PROFILE_EXTRACT_PROMPT.format(**context)
|
|
|
|
|
|
- def extract_profile_info(self, user_profile, dialogue_history: str) -> Optional[Dict]:
|
|
|
+ @staticmethod
|
|
|
+ def compose_dialogue(dialogue: List[Dict]) -> str:
|
|
|
+ role_map = {'user': '用户', 'assistant': '客服'}
|
|
|
+ messages = []
|
|
|
+ for msg in dialogue:
|
|
|
+ if not msg['content']:
|
|
|
+ continue
|
|
|
+ if msg['role'] not in role_map:
|
|
|
+ continue
|
|
|
+ messages.append('[{}] {}'.format(role_map[msg['role']], msg['content']))
|
|
|
+ return '\n'.join(messages)
|
|
|
+
|
|
|
+ def extract_profile_info(self, user_profile, dialogue_history: List[Dict]) -> Optional[Dict]:
|
|
|
"""
|
|
|
使用Function Calling提取用户画像信息
|
|
|
"""
|
|
@@ -89,11 +101,12 @@ class UserProfileExtractor:
|
|
|
|
|
|
try:
|
|
|
logger.debug("try to extract profile from message: {}".format(dialogue_history))
|
|
|
+ prompt = self.generate_extraction_prompt(user_profile, dialogue_history)
|
|
|
response = self.llm_client.chat.completions.create(
|
|
|
model=self.model_name,
|
|
|
messages=[
|
|
|
{"role": "system", "content": '你是一个专业的用户画像分析助手。'},
|
|
|
- {"role": "user", "content": self.generate_extraction_prompt(user_profile, dialogue_history)}
|
|
|
+ {"role": "user", "content": prompt}
|
|
|
],
|
|
|
tools=[self.get_extraction_function()],
|
|
|
temperature=0
|