user_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. from logging_service import logger
  5. from typing import Dict, Optional, Tuple, Any, List
  6. import json
  7. import time
  8. import os
  9. import abc
  10. import pymysql.cursors
  11. import configs
  12. from database import MySQLManager
  13. class UserManager(abc.ABC):
  14. @abc.abstractmethod
  15. def get_user_profile(self, user_id) -> Dict:
  16. pass
  17. @abc.abstractmethod
  18. def save_user_profile(self, user_id, profile: Dict) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def list_all_users(self):
  22. pass
  23. @abc.abstractmethod
  24. def get_staff_profile(self, staff_id) -> Dict:
  25. #FIXME(zhoutian): 重新设计用户和员工数据管理模型
  26. pass
  27. @staticmethod
  28. def get_default_profile(**kwargs) -> Dict:
  29. default_profile = {
  30. "name": "",
  31. "nickname": "",
  32. "avatar": "",
  33. "preferred_nickname": "",
  34. "gender": "未知",
  35. "age": 0,
  36. "region": '',
  37. "interests": [],
  38. "family_members": {},
  39. "health_conditions": [],
  40. "medications": [],
  41. "reminder_preferences": {
  42. "medication": True,
  43. "health": True,
  44. "weather": True,
  45. "news": False
  46. },
  47. "interaction_style": "standard", # standard, verbose, concise
  48. "interaction_frequency": "medium", # low, medium, high
  49. "last_topics": [],
  50. "created_at": int(time.time() * 1000),
  51. "human_intervention_history": []
  52. }
  53. for key, value in kwargs.items():
  54. if key in default_profile:
  55. default_profile[key] = value
  56. return default_profile
  57. def list_users(self, **kwargs) -> List[Dict]:
  58. pass
  59. class UserRelationManager(abc.ABC):
  60. @abc.abstractmethod
  61. def list_staffs(self):
  62. pass
  63. @abc.abstractmethod
  64. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  65. pass
  66. @abc.abstractmethod
  67. def list_staff_users(self) -> List[Dict]:
  68. pass
  69. @abc.abstractmethod
  70. def get_user_tags(self, user_id: str) -> List[str]:
  71. pass
  72. class LocalUserManager(UserManager):
  73. def get_user_profile(self, user_id) -> Dict:
  74. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  75. default_profile = self.get_default_profile()
  76. try:
  77. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  78. profile = json.load(f)
  79. entry_added = False
  80. for key, value in default_profile.items():
  81. if key not in profile:
  82. logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  83. profile[key] = value
  84. entry_added = True
  85. if entry_added:
  86. self.save_user_profile(user_id, profile)
  87. return profile
  88. except FileNotFoundError:
  89. # 创建默认用户资料
  90. self.save_user_profile(user_id, default_profile)
  91. return default_profile
  92. def save_user_profile(self, user_id, profile: Dict) -> None:
  93. if not user_id:
  94. raise Exception("Invalid user_id: {}".format(user_id))
  95. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  96. json.dump(profile, f, ensure_ascii=False, indent=2)
  97. def list_all_users(self):
  98. user_ids = []
  99. for root, dirs, files in os.walk('user_profiles/'):
  100. for file in files:
  101. if file.endswith('.json'):
  102. user_ids.append(os.path.splitext(file)[0])
  103. return user_ids
  104. def get_staff_profile(self, staff_id) -> Dict:
  105. # for test only
  106. return {
  107. 'agent_name': '小芳',
  108. 'agent_gender': '女',
  109. 'agent_age': 30,
  110. 'agent_region': '北京'
  111. }
  112. def list_users(self, **kwargs) -> List[Dict]:
  113. pass
  114. class MySQLUserManager(UserManager):
  115. PROFILE_EXCLUDE_ITEMS = ['avatar', ]
  116. def __init__(self, db_config, table_name, staff_table):
  117. self.db = MySQLManager(db_config)
  118. self.table_name = table_name
  119. self.staff_table = staff_table
  120. def get_user_profile(self, user_id) -> Dict:
  121. sql = f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar" \
  122. f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  123. data = self.db.select(sql, pymysql.cursors.DictCursor)
  124. if not data:
  125. logger.error(f"user[{user_id}] not found")
  126. return {}
  127. data = data[0]
  128. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  129. gender = gender_map[data['gender']]
  130. default_profile = self.get_default_profile(nickname=data['name'], gender=gender, avatar=data['avatar'])
  131. if not data['profile_data_v1']:
  132. logger.warning(f"user[{user_id}] profile not found, create a default one")
  133. self.save_user_profile(user_id, default_profile)
  134. return default_profile
  135. else:
  136. profile = json.loads(data['profile_data_v1'])
  137. # 资料条目有增加时,需合并更新
  138. entry_added = False
  139. for key, value in default_profile.items():
  140. if key not in profile:
  141. logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  142. profile[key] = value
  143. entry_added = True
  144. if entry_added:
  145. self.save_user_profile(user_id, profile)
  146. return profile
  147. def save_user_profile(self, user_id, profile: Dict) -> None:
  148. if not user_id:
  149. raise Exception("Invalid user_id: {}".format(user_id))
  150. if configs.get().get('debug_flags', {}).get('disable_database_write', False):
  151. return
  152. profile = profile.copy()
  153. for name in self.PROFILE_EXCLUDE_ITEMS:
  154. profile.pop(name, None)
  155. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  156. self.db.execute(sql, (json.dumps(profile),))
  157. def list_all_users(self):
  158. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  159. data = self.db.select(sql, pymysql.cursors.DictCursor)
  160. return [user['third_party_user_id'] for user in data]
  161. def get_staff_profile(self, staff_id) -> Dict:
  162. if not self.staff_table:
  163. raise Exception("staff_table is not set")
  164. sql = f"SELECT agent_name, agent_gender, agent_age, agent_region " \
  165. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  166. data = self.db.select(sql, pymysql.cursors.DictCursor)
  167. if not data:
  168. logger.error(f"staff[{staff_id}] not found")
  169. return {}
  170. profile = data[0]
  171. # 转换性别格式
  172. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  173. profile['agent_gender'] = gender_map[profile['agent_gender']]
  174. return profile
  175. def list_users(self, **kwargs) -> List[Dict]:
  176. user_union_id = kwargs.get('user_union_id', None)
  177. user_name = kwargs.get('user_name', None)
  178. if not user_union_id and not user_name:
  179. raise Exception("user_union_id or user_name is required")
  180. sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
  181. if user_name:
  182. sql += f"AND name = '{user_name}' COLLATE utf8mb4_bin "
  183. if user_union_id:
  184. sql += f"AND wxid = '{user_union_id}' "
  185. data = self.db.select(sql, pymysql.cursors.DictCursor)
  186. return data
  187. class MySQLUserRelationManager(UserRelationManager):
  188. def __init__(self, agent_db_config, wecom_db_config,
  189. agent_staff_table, agent_user_table,
  190. staff_table, relation_table, user_table):
  191. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  192. self.agent_db = MySQLManager(agent_db_config)
  193. self.wecom_db = MySQLManager(wecom_db_config)
  194. self.agent_staff_table = agent_staff_table
  195. self.staff_table = staff_table
  196. self.relation_table = relation_table
  197. self.agent_user_table = agent_user_table
  198. self.user_table = user_table
  199. def list_staffs(self):
  200. sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
  201. data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  202. return data
  203. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  204. return []
  205. def list_staff_users(self):
  206. # FIXME(zhoutian)
  207. # 测试期间逻辑,只取一个账号
  208. sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  209. f" AND third_party_user_id in ('1688854492669990', '1688855931724582')")
  210. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  211. if not agent_staff_data:
  212. return []
  213. ret = []
  214. for agent_staff in agent_staff_data:
  215. wxid = agent_staff['wxid']
  216. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  217. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  218. if not staff_data:
  219. logger.error(f"staff[{wxid}] not found in wecom database")
  220. continue
  221. staff_id = staff_data[0]['id']
  222. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  223. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  224. if not user_data:
  225. logger.warning(f"staff[{wxid}] has no user")
  226. continue
  227. user_ids = tuple(user['user_id'] for user in user_data)
  228. # FIXME(zhoutian): 测试期间临时逻辑
  229. if agent_staff['third_party_user_id'] == '1688854492669990':
  230. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
  231. else:
  232. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null" \
  233. f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = 15 and is_delete = 0)"
  234. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  235. if not user_data:
  236. logger.warning(f"staff[{wxid}] users not found in wecom database")
  237. continue
  238. user_union_ids = tuple(user['union_id'] for user in user_data)
  239. batch_size = 100
  240. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  241. agent_user_data = []
  242. for i in range(n_batches):
  243. idx_begin = i * batch_size
  244. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  245. batch_union_ids = user_union_ids[idx_begin:idx_end]
  246. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  247. batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  248. if len(agent_user_data) != len(batch_union_ids):
  249. # logger.debug(f"staff[{wxid}] some users not found in agent database")
  250. pass
  251. agent_user_data.extend(batch_agent_user_data)
  252. staff_user_pairs = [
  253. {
  254. 'staff_id': agent_staff['third_party_user_id'],
  255. 'user_id': agent_user['third_party_user_id']
  256. }
  257. for agent_user in agent_user_data
  258. ]
  259. ret.extend(staff_user_pairs)
  260. return ret
  261. def get_user_tags(self, user_id: str) -> List[str]:
  262. sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
  263. user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  264. if not user_data:
  265. logger.error(f"user[{user_id}] has no wxid")
  266. return []
  267. user_wxid = user_data[0]['wxid']
  268. sql = f"""
  269. select b.tag_id, c.`tag_name` from `we_com_user` as a
  270. join `we_com_user_with_tag` as b
  271. join `we_com_tag` as c
  272. on a.`id` = b.`user_id`
  273. and b.`tag_id` = c.id
  274. where a.union_id = '{user_wxid}' """
  275. tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  276. tag_names = [tag['tag_name'] for tag in tag_data]
  277. return tag_names
  278. if __name__ == '__main__':
  279. config = configs.get()
  280. user_db_config = config['storage']['user']
  281. staff_db_config = config['storage']['staff']
  282. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  283. user_profile = user_manager.get_user_profile('7881301263964433')
  284. print(user_profile)
  285. wecom_db_config = config['storage']['user_relation']
  286. user_relation_manager = MySQLUserRelationManager(
  287. user_db_config['mysql'], wecom_db_config['mysql'],
  288. config['storage']['staff']['table'],
  289. user_db_config['table'],
  290. wecom_db_config['table']['staff'],
  291. wecom_db_config['table']['relation'],
  292. wecom_db_config['table']['user']
  293. )
  294. # all_staff_users = user_relation_manager.list_staff_users()
  295. user_tags = user_relation_manager.get_user_tags('7881302078008656')
  296. print(user_tags)