# user_profile_extractor.py
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. from typing import Dict, Any, Optional
  6. import chat_service
  7. import configs
  8. from prompt_templates import USER_PROFILE_EXTRACT_PROMPT
  9. from openai import OpenAI
  10. from logging_service import logger
  11. import global_flags
  12. class UserProfileExtractor:
  13. def __init__(self):
  14. self.llm_client = OpenAI(
  15. api_key=chat_service.VOLCENGINE_API_TOKEN,
  16. base_url=chat_service.VOLCENGINE_BASE_URL
  17. )
  18. self.model_name = chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5
  19. def get_extraction_function(self) -> Dict:
  20. """
  21. 定义用于用户画像信息提取的Function Calling函数
  22. """
  23. return {
  24. "type": "function",
  25. "function": {
  26. "name": "update_user_profile",
  27. "description": "从用户对话中提取并更新用户的个人信息",
  28. "parameters": {
  29. "type": "object",
  30. "properties": {
  31. "name": {
  32. "type": "string",
  33. "description": "用户的姓名,如果能够准确识别"
  34. },
  35. "preferred_nickname": {
  36. "type": "string",
  37. "description": "用户希望对其的称呼,如果能够准确识别"
  38. },
  39. "gender": {
  40. "type": "string",
  41. "description": "用户的性别,男或女,如果不能准确识别则为未知"
  42. },
  43. "age": {
  44. "type": "integer",
  45. "description": "用户的年龄,如果能够准确识别"
  46. },
  47. "region": {
  48. "type": "string",
  49. "description": "用户常驻的地区,不是用户临时所在地"
  50. },
  51. "interests": {
  52. "type": "array",
  53. "items": {"type": "string"},
  54. "description": "用户提到的自己的兴趣爱好"
  55. },
  56. "health_conditions": {
  57. "type": "array",
  58. "items": {"type": "string"},
  59. "description": "用户提及的健康状况"
  60. },
  61. "interaction_frequency": {
  62. "type": "string",
  63. "description": "用户期望的交互频率。每2天联系小于1次为low,每天联系1次为medium,不再联系为stopped"
  64. }
  65. },
  66. "required": []
  67. }
  68. }
  69. }
  70. def generate_extraction_prompt(self, user_profile: Dict, dialogue_history: str) -> str:
  71. """
  72. 生成用于信息提取的系统提示词
  73. """
  74. context = user_profile.copy()
  75. context['dialogue_history'] = dialogue_history
  76. return USER_PROFILE_EXTRACT_PROMPT.format(**context)
  77. def extract_profile_info(self, user_profile, dialogue_history: str) -> Optional[Dict]:
  78. """
  79. 使用Function Calling提取用户画像信息
  80. """
  81. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  82. return None
  83. try:
  84. logger.debug("try to extract profile from message: {}".format(dialogue_history))
  85. response = self.llm_client.chat.completions.create(
  86. model=self.model_name,
  87. messages=[
  88. {"role": "system", "content": '你是一个专业的用户画像分析助手。'},
  89. {"role": "user", "content": self.generate_extraction_prompt(user_profile, dialogue_history)}
  90. ],
  91. tools=[self.get_extraction_function()],
  92. temperature=0
  93. )
  94. # 解析Function Call的参数
  95. tool_calls = response.choices[0].message.tool_calls
  96. logger.debug(response)
  97. if tool_calls:
  98. function_call = tool_calls[0]
  99. if function_call.function.name == 'update_user_profile':
  100. try:
  101. profile_info = json.loads(function_call.function.arguments)
  102. return {k: v for k, v in profile_info.items() if v}
  103. except json.JSONDecodeError:
  104. logger.error("无法解析提取的用户信息")
  105. return None
  106. except Exception as e:
  107. logger.error(f"用户画像提取出错: {e}")
  108. return None
  109. def merge_profile_info(self, existing_profile: Dict, new_info: Dict) -> Dict:
  110. """
  111. 合并新提取的用户信息到现有资料
  112. """
  113. merged_profile = existing_profile.copy()
  114. merged_profile.update(new_info)
  115. return merged_profile
  116. if __name__ == '__main__':
  117. extractor = UserProfileExtractor()
  118. current_profile = {
  119. 'name': '',
  120. 'preferred_nickname': '李叔',
  121. "gender": "男",
  122. 'age': 0,
  123. 'region': '北京',
  124. 'health_conditions': [],
  125. 'medications': [],
  126. 'interests': [],
  127. 'interaction_frequency': 'medium'
  128. }
  129. message = "没有任何问题放心,不会骚扰你了,再见"
  130. resp = extractor.extract_profile_info(current_profile, message)
  131. print(resp)
  132. message = "好的,孩子,我是老李头,今年68啦,住在北京海淀区。平时喜欢在微信上跟老伙伴们聊聊养生、下下象棋,偶尔也跟年轻人学学新鲜事儿。\n" \
  133. "你叫我李叔就行,有啥事儿咱们慢慢聊啊\n" \
  134. "哎,今儿个天气不错啊,我刚才还去楼下小公园溜达了一圈儿。碰到几个老伙计在打太极,我也跟着比划了两下,这老胳膊老腿的,原来老不舒服,活动活动舒坦多了!\n" \
  135. "你吃饭了没?我们这儿中午吃的打卤面,老伴儿做的,香得很!这人老了就爱念叨些家长里短的,你可别嫌我啰嗦啊。\n" \
  136. "对了,最近我孙子教我发语音,比打字方便多啦!就是有时候一激动,说话声音太大,把手机都给震得嗡嗡响\n"
  137. resp = extractor.extract_profile_info(current_profile, message)
  138. print(resp)
  139. print(extractor.merge_profile_info(current_profile, resp))
  140. current_profile = {
  141. 'name': '李老头',
  142. 'preferred_nickname': '李叔',
  143. "gender": "男",
  144. 'age': 68,
  145. 'region': '北京市海淀区',
  146. 'health_conditions': [],
  147. 'medications': [],
  148. 'interests': ['养生', '下象棋']
  149. }
  150. resp = extractor.extract_profile_info(current_profile, message)
  151. print(resp)
  152. print(extractor.merge_profile_info(current_profile, resp))