user_profile_extractor.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. from typing import Dict, Optional, List
  6. from pqai_agent import chat_service, configs
  7. from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
  8. from openai import OpenAI
  9. from pqai_agent.logging_service import logger
  10. from pqai_agent.utils import prompt_utils
  11. class UserProfileExtractor:
  12. FIELDS = [
  13. "name",
  14. "preferred_nickname",
  15. "gender",
  16. "age",
  17. "region",
  18. "interests",
  19. "health_conditions",
  20. "interaction_frequency",
  21. "flexible_params"
  22. ]
  23. def __init__(self, model_name=None, llm_client=None):
  24. if not llm_client:
  25. self.llm_client = OpenAI(
  26. api_key=chat_service.VOLCENGINE_API_TOKEN,
  27. base_url=chat_service.VOLCENGINE_BASE_URL
  28. )
  29. else:
  30. self.llm_client = llm_client
  31. if not model_name:
  32. model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
  33. self.model_name = model_name
  34. @staticmethod
  35. def get_extraction_function() -> Dict:
  36. """
  37. 定义用于用户画像信息提取的Function Calling函数
  38. """
  39. return {
  40. "type": "function",
  41. "function": {
  42. "name": "update_user_profile",
  43. "description": "从用户对话中提取并更新用户的个人信息",
  44. "parameters": {
  45. "type": "object",
  46. "properties": {
  47. "name": {
  48. "type": "string",
  49. "description": "用户的姓名,如果能够准确识别"
  50. },
  51. "preferred_nickname": {
  52. "type": "string",
  53. "description": "用户希望客服对用户的称呼,如果用户明确提到"
  54. },
  55. "gender": {
  56. "type": "string",
  57. "description": "用户的性别,男或女,如果不能准确识别则为未知"
  58. },
  59. "age": {
  60. "type": "integer",
  61. "description": "用户的年龄,如果能够准确识别"
  62. },
  63. "region": {
  64. "type": "string",
  65. "description": "用户常驻的地区,不是用户临时所在地"
  66. },
  67. "interests": {
  68. "type": "array",
  69. "items": {"type": "string"},
  70. "description": "用户提到的自己的兴趣爱好"
  71. },
  72. "health_conditions": {
  73. "type": "array",
  74. "items": {"type": "string"},
  75. "description": "用户提及的健康状况"
  76. },
  77. "interaction_frequency": {
  78. "type": "string",
  79. "description": "用户期望的交互频率。每2天联系小于1次为low,每天联系1次为medium,未来均不再联系为stopped"
  80. }
  81. },
  82. "required": []
  83. }
  84. }
  85. }
  86. def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template = USER_PROFILE_EXTRACT_PROMPT) -> str:
  87. """
  88. 生成用于信息提取的系统提示词
  89. """
  90. context = user_profile.copy()
  91. context['dialogue_history'] = self.compose_dialogue(dialogue_history)
  92. context['formatted_user_profile'] = prompt_utils.format_user_profile(user_profile)
  93. return prompt_template.format(**context)
  94. @staticmethod
  95. def compose_dialogue(dialogue: List[Dict]) -> str:
  96. role_map = {'user': '用户', 'assistant': '客服'}
  97. messages = []
  98. for msg in dialogue:
  99. if not msg['content']:
  100. continue
  101. if msg['role'] not in role_map:
  102. continue
  103. messages.append('[{}] {}'.format(role_map[msg['role']], msg['content']))
  104. return '\n'.join(messages)
  105. def extract_profile_info(self, user_profile, dialogue_history: List[Dict]) -> Optional[Dict]:
  106. """
  107. 使用Function Calling提取用户画像信息
  108. """
  109. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  110. return None
  111. try:
  112. logger.debug("try to extract profile from message: {}".format(dialogue_history))
  113. prompt = self.generate_extraction_prompt(user_profile, dialogue_history)
  114. response = self.llm_client.chat.completions.create(
  115. model=self.model_name,
  116. messages=[
  117. {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
  118. {"role": "user", "content": prompt}
  119. ],
  120. tools=[self.get_extraction_function()],
  121. temperature=0
  122. )
  123. # 解析Function Call的参数
  124. tool_calls = response.choices[0].message.tool_calls
  125. logger.debug(response)
  126. if tool_calls:
  127. function_call = tool_calls[0]
  128. if function_call.function.name == 'update_user_profile':
  129. try:
  130. profile_info = json.loads(function_call.function.arguments)
  131. return {k: v for k, v in profile_info.items() if v}
  132. except json.JSONDecodeError:
  133. logger.error("无法解析提取的用户信息")
  134. return None
  135. except Exception as e:
  136. logger.error(f"用户画像提取出错: {e}")
  137. return None
  138. def extract_profile_info_v2(self, user_profile: Dict, dialogue_history: List[Dict], prompt_template: Optional[str] = None) -> Optional[Dict]:
  139. """
  140. 使用JSON输出提取用户画像信息
  141. :param user_profile:
  142. :param dialogue_history:
  143. :param prompt_template: 可选的自定义提示模板
  144. :return:
  145. """
  146. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  147. return None
  148. try:
  149. logger.debug("try to extract profile from message: {}".format(dialogue_history))
  150. prompt_template = prompt_template or USER_PROFILE_EXTRACT_PROMPT_V2
  151. prompt = self.generate_extraction_prompt(user_profile, dialogue_history, prompt_template)
  152. print(prompt)
  153. response = self.llm_client.chat.completions.create(
  154. model=self.model_name,
  155. messages=[
  156. {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
  157. {"role": "user", "content": prompt}
  158. ],
  159. temperature=0
  160. )
  161. json_data = response.choices[0].message.content \
  162. .replace("```", "").replace("```json", "").strip()
  163. try:
  164. profile_info = json.loads(json_data)
  165. except json.JSONDecodeError as e:
  166. logger.error(f"Error in JSON decode: {e}, original input: {json_data}")
  167. return None
  168. return profile_info
  169. except Exception as e:
  170. logger.error(f"用户画像提取出错: {e}")
  171. return None
  172. def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict:
  173. """
  174. 合并新提取的用户信息到现有资料
  175. """
  176. merged_profile = existing_profile.copy()
  177. for field in new_info:
  178. if field in self.FIELDS:
  179. merged_profile[field] = new_info[field]
  180. else:
  181. logger.warning(f"Unknown field in new profile: {field}")
  182. return merged_profile
  183. if __name__ == '__main__':
  184. from pqai_agent import configs
  185. from pqai_agent import logging_service
  186. logging_service.setup_root_logger()
  187. config = configs.get()
  188. config['debug_flags']['disable_llm_api_call'] = False
  189. extractor = UserProfileExtractor()
  190. current_profile = {
  191. 'name': '',
  192. 'preferred_nickname': '李叔',
  193. "gender": "男",
  194. 'age': 0,
  195. 'region': '北京',
  196. 'health_conditions': [],
  197. 'medications': [],
  198. 'interests': [],
  199. 'interaction_frequency': 'medium'
  200. }
  201. messages= [
  202. {'role': 'user', 'content': "没有任何问题放心,以后不要再发了,再见"}
  203. ]
  204. # resp = extractor.extract_profile_info_v2(current_profile, messages)
  205. # logger.warning(resp)
  206. message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
  207. "你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
  208. "哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \
  209. "你吃饭了没?我们这儿中午吃的打卤面,老伴儿做的,香得很!这人老了就爱念叨些家长里短的,你可别嫌我啰嗦啊。\n" \
  210. "对了,最近我孙子教我发语音,比打字方便多啦!就是有时候一激动,说话声音太大,把手机都给震得嗡嗡响\n"
  211. messages = []
  212. for line in message.split("\n"):
  213. messages.append({'role': 'user', 'content': line})
  214. resp = extractor.extract_profile_info_v2(current_profile, messages)
  215. logger.warning(resp)
  216. merged_profile = extractor.merge_profile_info(current_profile, resp)
  217. logger.warning(merged_profile)
  218. current_profile = {
  219. 'name': '李老头',
  220. 'preferred_nickname': '李叔',
  221. "gender": "男",
  222. 'age': 68,
  223. 'region': '北京市海淀区',
  224. 'health_conditions': [],
  225. 'medications': [],
  226. 'interests': ['养生', '下象棋'],
  227. 'interaction_frequency': 'medium'
  228. }
  229. resp = extractor.extract_profile_info_v2(merged_profile, messages)
  230. logger.warning(resp)
  231. logger.warning(extractor.merge_profile_info(current_profile, resp))