|
@@ -5,20 +5,36 @@
|
|
|
import json
|
|
|
from typing import Dict, Optional, List
|
|
|
|
|
|
-from pqai_agent import chat_service
|
|
|
-from pqai_agent import configs
|
|
|
-from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT
|
|
|
+from pqai_agent import chat_service, configs
|
|
|
+from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
|
|
|
from openai import OpenAI
|
|
|
from pqai_agent.logging_service import logger
|
|
|
+from pqai_agent.utils import prompt_utils
|
|
|
|
|
|
|
|
|
class UserProfileExtractor:
|
|
|
- def __init__(self):
|
|
|
- self.llm_client = OpenAI(
|
|
|
- api_key=chat_service.VOLCENGINE_API_TOKEN,
|
|
|
- base_url=chat_service.VOLCENGINE_BASE_URL
|
|
|
- )
|
|
|
- self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
# Profile keys recognised by this extractor; merge_profile_info only copies
# keys from newly extracted data when they appear in this list.
FIELDS = [
    "name", "preferred_nickname", "gender", "age", "region",
    "interests", "health_conditions", "interaction_frequency",
    "flexible_params",
]
|
|
|
+ def __init__(self, model_name=None, llm_client=None):
|
|
|
+ if not llm_client:
|
|
|
+ self.llm_client = OpenAI(
|
|
|
+ api_key=chat_service.VOLCENGINE_API_TOKEN,
|
|
|
+ base_url=chat_service.VOLCENGINE_BASE_URL
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.llm_client = llm_client
|
|
|
+ if not model_name:
|
|
|
+ model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
+ self.model_name = model_name
|
|
|
|
|
|
@staticmethod
|
|
|
def get_extraction_function() -> Dict:
|
|
@@ -73,13 +89,14 @@ class UserProfileExtractor:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template=USER_PROFILE_EXTRACT_PROMPT) -> str:
    """
    Build the prompt text used for profile extraction.

    The template is filled with every key of ``user_profile`` plus two
    derived values: ``dialogue_history`` (from compose_dialogue) and
    ``formatted_user_profile`` (from prompt_utils.format_user_profile).

    :param user_profile: current profile; copied, never modified in place
    :param dialogue_history: message dicts handed to compose_dialogue
    :param prompt_template: ``str.format``-style template to fill
    :return: the formatted prompt
    """
    template_values = user_profile.copy()
    # The derived keys are written after the copy, so they override any
    # same-named keys that came from the profile itself.
    template_values.update(
        dialogue_history=self.compose_dialogue(dialogue_history),
        formatted_user_profile=prompt_utils.format_user_profile(user_profile),
    )
    return prompt_template.format(**template_values)
|
|
|
|
|
|
@staticmethod
|
|
|
def compose_dialogue(dialogue: List[Dict]) -> str:
|
|
@@ -130,15 +147,61 @@ class UserProfileExtractor:
|
|
|
logger.error(f"用户画像提取出错: {e}")
|
|
|
return None
|
|
|
|
|
|
def extract_profile_info_v2(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template: Optional[str] = None) -> Optional[Dict]:
    """
    Extract user profile information by asking the LLM for a JSON reply.

    :param user_profile: current profile, forwarded to
        generate_extraction_prompt to fill the template
    :param dialogue_history: message dicts, forwarded to
        generate_extraction_prompt
    :param prompt_template: optional custom prompt template; when falsy,
        USER_PROFILE_EXTRACT_PROMPT_V2 is used
    :return: the decoded JSON object, or None when LLM calls are disabled
        via the ``debug_flags.disable_llm_api_call`` config flag, when the
        reply is not a JSON object, or when any error occurs
    """
    if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
        return None

    try:
        logger.debug("try to extract profile from message: {}".format(dialogue_history))
        prompt_template = prompt_template or USER_PROFILE_EXTRACT_PROMPT_V2
        prompt = self.generate_extraction_prompt(user_profile, dialogue_history, prompt_template)
        # Route the prompt through the logger rather than stdout: it embeds
        # the user's profile and dialogue.
        logger.debug("profile extraction prompt: {}".format(prompt))
        response = self.llm_client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
                {"role": "user", "content": prompt}
            ],
            temperature=0
        )
        # Strip the longer "```json" marker BEFORE the bare "```" marker.
        # In the opposite order the bare replace consumes the backticks
        # first and leaves a stray "json" prefix that breaks json.loads.
        json_data = response.choices[0].message.content \
            .replace("```json", "").replace("```", "").strip()
        try:
            profile_info = json.loads(json_data)
        except json.JSONDecodeError as e:
            logger.error(f"Error in JSON decode: {e}, original input: {json_data}")
            return None
        # Valid JSON can still be a list/string/number; callers iterate the
        # result as a field -> value mapping, so reject anything else.
        if not isinstance(profile_info, dict):
            logger.error(f"Extracted profile is not a JSON object, original input: {json_data}")
            return None
        return profile_info

    except Exception as e:
        # Best-effort boundary: any client/network/attribute failure is
        # logged and reported to the caller as "nothing extracted".
        logger.error(f"用户画像提取出错: {e}")
        return None
|
|
|
+
|
|
|
def merge_profile_info(self, existing_profile: Dict, new_info: Optional[Dict]) -> Dict:
    """
    Merge newly extracted profile information into an existing profile.

    Only keys listed in ``self.FIELDS`` are copied; any other key is
    logged as a warning and ignored. ``existing_profile`` is not modified.

    :param existing_profile: the profile to start from
    :param new_info: freshly extracted fields; None or empty means there
        is nothing to merge
    :return: a new dict holding the merged profile
    """
    merged_profile = existing_profile.copy()
    # The extractors return None on failure; treat that as "no new data"
    # rather than raising TypeError when iterating.
    if not new_info:
        return merged_profile
    for field, value in new_info.items():
        if field in self.FIELDS:
            merged_profile[field] = value
        else:
            logger.warning(f"Unknown field in new profile: {field}")
    return merged_profile
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
+ from pqai_agent import configs
|
|
|
+ from pqai_agent import logging_service
|
|
|
+ logging_service.setup_root_logger()
|
|
|
+ config = configs.get()
|
|
|
+ config['debug_flags']['disable_llm_api_call'] = False
|
|
|
extractor = UserProfileExtractor()
|
|
|
current_profile = {
|
|
|
'name': '',
|
|
@@ -152,11 +215,11 @@ if __name__ == '__main__':
|
|
|
'interaction_frequency': 'medium'
|
|
|
}
|
|
|
messages= [
|
|
|
- {'role': 'user', 'content': "没有任何问题放心,不会骚扰你了,再见"}
|
|
|
+ {'role': 'user', 'content': "没有任何问题放心,以后不要再发了,再见"}
|
|
|
]
|
|
|
|
|
|
- resp = extractor.extract_profile_info(current_profile, messages)
|
|
|
- print(resp)
|
|
|
+ # resp = extractor.extract_profile_info_v2(current_profile, messages)
|
|
|
+ # logger.warning(resp)
|
|
|
message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
|
|
|
"你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
|
|
|
"哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \
|
|
@@ -165,9 +228,10 @@ if __name__ == '__main__':
|
|
|
messages = []
|
|
|
for line in message.split("\n"):
|
|
|
messages.append({'role': 'user', 'content': line})
|
|
|
- resp = extractor.extract_profile_info(current_profile, messages)
|
|
|
- print(resp)
|
|
|
- print(extractor.merge_profile_info(current_profile, resp))
|
|
|
+ resp = extractor.extract_profile_info_v2(current_profile, messages)
|
|
|
+ logger.warning(resp)
|
|
|
+ merged_profile = extractor.merge_profile_info(current_profile, resp)
|
|
|
+ logger.warning(merged_profile)
|
|
|
current_profile = {
|
|
|
'name': '李老头',
|
|
|
'preferred_nickname': '李叔',
|
|
@@ -179,6 +243,6 @@ if __name__ == '__main__':
|
|
|
'interests': ['养生', '下象棋'],
|
|
|
'interaction_frequency': 'medium'
|
|
|
}
|
|
|
- resp = extractor.extract_profile_info(current_profile, messages)
|
|
|
- print(resp)
|
|
|
- print(extractor.merge_profile_info(current_profile, resp))
|
|
|
+ resp = extractor.extract_profile_info_v2(merged_profile, messages)
|
|
|
+ logger.warning(resp)
|
|
|
+ logger.warning(extractor.merge_profile_info(current_profile, resp))
|