user_manager.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. from logging_service import logger
  5. from typing import Dict, Optional, Tuple, Any, List
  6. import json
  7. import time
  8. import os
  9. import abc
  10. import pymysql.cursors
  11. import configs
  12. from database import MySQLManager
  13. class UserManager(abc.ABC):
  14. @abc.abstractmethod
  15. def get_user_profile(self, user_id) -> Dict:
  16. pass
  17. @abc.abstractmethod
  18. def save_user_profile(self, user_id, profile: Dict) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def list_all_users(self):
  22. pass
  23. @abc.abstractmethod
  24. def get_staff_profile(self, staff_id) -> Dict:
  25. #FIXME(zhoutian): 重新设计用户和员工数据管理模型
  26. pass
  27. @staticmethod
  28. def get_default_profile(**kwargs) -> Dict:
  29. default_profile = {
  30. "name": "",
  31. "nickname": "",
  32. "avatar": "",
  33. "preferred_nickname": "",
  34. "gender": "未知",
  35. "age": 0,
  36. "region": '',
  37. "interests": [],
  38. "family_members": {},
  39. "health_conditions": [],
  40. "medications": [],
  41. "reminder_preferences": {
  42. "medication": True,
  43. "health": True,
  44. "weather": True,
  45. "news": False
  46. },
  47. "interaction_style": "standard", # standard, verbose, concise
  48. "interaction_frequency": "medium", # low, medium, high
  49. "last_topics": [],
  50. "created_at": int(time.time() * 1000),
  51. "human_intervention_history": []
  52. }
  53. for key, value in kwargs.items():
  54. if key in default_profile:
  55. default_profile[key] = value
  56. return default_profile
  57. def list_users(self, **kwargs) -> List[Dict]:
  58. pass
  59. class UserRelationManager(abc.ABC):
  60. @abc.abstractmethod
  61. def list_staffs(self):
  62. pass
  63. @abc.abstractmethod
  64. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  65. pass
  66. @abc.abstractmethod
  67. def list_staff_users(self) -> List[Dict]:
  68. pass
  69. @abc.abstractmethod
  70. def get_user_tags(self, user_id: str) -> List[str]:
  71. pass
  72. class LocalUserManager(UserManager):
  73. def get_user_profile(self, user_id) -> Dict:
  74. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  75. default_profile = self.get_default_profile()
  76. try:
  77. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  78. profile = json.load(f)
  79. entry_added = False
  80. for key, value in default_profile.items():
  81. if key not in profile:
  82. logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  83. profile[key] = value
  84. entry_added = True
  85. if entry_added:
  86. self.save_user_profile(user_id, profile)
  87. return profile
  88. except FileNotFoundError:
  89. # 创建默认用户资料
  90. self.save_user_profile(user_id, default_profile)
  91. return default_profile
  92. def save_user_profile(self, user_id, profile: Dict) -> None:
  93. if not user_id:
  94. raise Exception("Invalid user_id: {}".format(user_id))
  95. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  96. json.dump(profile, f, ensure_ascii=False, indent=2)
  97. def list_all_users(self):
  98. user_ids = []
  99. for root, dirs, files in os.walk('user_profiles/'):
  100. for file in files:
  101. if file.endswith('.json'):
  102. user_ids.append(os.path.splitext(file)[0])
  103. return user_ids
  104. def get_staff_profile(self, staff_id) -> Dict:
  105. # for test only
  106. return {
  107. 'agent_name': '小芳',
  108. 'agent_gender': '女',
  109. 'agent_age': 30,
  110. 'agent_region': '北京'
  111. }
  112. def list_users(self, **kwargs) -> List[Dict]:
  113. pass
  114. class MySQLUserManager(UserManager):
  115. PROFILE_EXCLUDE_ITEMS = ['avatar', ]
  116. def __init__(self, db_config, table_name, staff_table):
  117. self.db = MySQLManager(db_config)
  118. self.table_name = table_name
  119. self.staff_table = staff_table
  120. def get_user_profile(self, user_id) -> Dict:
  121. sql = f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar" \
  122. f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  123. data = self.db.select(sql, pymysql.cursors.DictCursor)
  124. if not data:
  125. logger.error(f"user[{user_id}] not found")
  126. return {}
  127. data = data[0]
  128. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  129. gender = gender_map[data['gender']]
  130. default_profile = self.get_default_profile(nickname=data['name'], gender=gender, avatar=data['avatar'])
  131. if not data['profile_data_v1']:
  132. logger.warning(f"user[{user_id}] profile not found, create a default one")
  133. self.save_user_profile(user_id, default_profile)
  134. return default_profile
  135. else:
  136. profile = json.loads(data['profile_data_v1'])
  137. # 资料条目有增加时,需合并更新
  138. entry_added = False
  139. for key, value in default_profile.items():
  140. if key not in profile:
  141. # logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  142. profile[key] = value
  143. entry_added = True
  144. if entry_added:
  145. self.save_user_profile(user_id, profile)
  146. return profile
  147. def save_user_profile(self, user_id, profile: Dict) -> None:
  148. if not user_id:
  149. raise Exception("Invalid user_id: {}".format(user_id))
  150. if configs.get().get('debug_flags', {}).get('disable_database_write', False):
  151. return
  152. profile = profile.copy()
  153. for name in self.PROFILE_EXCLUDE_ITEMS:
  154. profile.pop(name, None)
  155. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  156. self.db.execute(sql, (json.dumps(profile),))
  157. def list_all_users(self):
  158. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  159. data = self.db.select(sql, pymysql.cursors.DictCursor)
  160. return [user['third_party_user_id'] for user in data]
  161. def get_staff_profile(self, staff_id) -> Dict:
  162. if not self.staff_table:
  163. raise Exception("staff_table is not set")
  164. sql = f"SELECT agent_name, agent_gender, agent_age, agent_region " \
  165. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  166. data = self.db.select(sql, pymysql.cursors.DictCursor)
  167. if not data:
  168. logger.error(f"staff[{staff_id}] not found")
  169. return {}
  170. profile = data[0]
  171. # 转换性别格式
  172. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  173. profile['agent_gender'] = gender_map[profile['agent_gender']]
  174. return profile
  175. def list_users(self, **kwargs) -> List[Dict]:
  176. user_union_id = kwargs.get('user_union_id', None)
  177. user_name = kwargs.get('user_name', None)
  178. if not user_union_id and not user_name:
  179. raise Exception("user_union_id or user_name is required")
  180. sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
  181. if user_name:
  182. sql += f"AND name = '{user_name}' COLLATE utf8mb4_bin "
  183. if user_union_id:
  184. sql += f"AND wxid = '{user_union_id}' "
  185. data = self.db.select(sql, pymysql.cursors.DictCursor)
  186. return data
  187. class LocalUserRelationManager(UserRelationManager):
  188. def __init__(self):
  189. pass
  190. def list_staffs(self):
  191. return [
  192. {"third_party_user_id": 0, "name": "x", "wxid": "x", "agent_name": "小芳"}
  193. ]
  194. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  195. return []
  196. def list_staff_users(self):
  197. return [
  198. {"staff_id": "1688854492669990", "user_id": "7881299670930896"}
  199. ]
  200. def get_user_tags(self, user_id: str):
  201. return []
  202. class MySQLUserRelationManager(UserRelationManager):
  203. def __init__(self, agent_db_config, wecom_db_config,
  204. agent_staff_table, agent_user_table,
  205. staff_table, relation_table, user_table):
  206. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  207. self.agent_db = MySQLManager(agent_db_config)
  208. self.wecom_db = MySQLManager(wecom_db_config)
  209. self.agent_staff_table = agent_staff_table
  210. self.staff_table = staff_table
  211. self.relation_table = relation_table
  212. self.agent_user_table = agent_user_table
  213. self.user_table = user_table
  214. def list_staffs(self):
  215. sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
  216. data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  217. return data
  218. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  219. return []
  220. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  221. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  222. if staff_id:
  223. sql += f" AND third_party_user_id = '{staff_id}'"
  224. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  225. if not agent_staff_data:
  226. return []
  227. ret = []
  228. for agent_staff in agent_staff_data:
  229. wxid = agent_staff['wxid']
  230. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  231. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  232. if not staff_data:
  233. logger.error(f"staff[{wxid}] not found in wecom database")
  234. continue
  235. staff_id = staff_data[0]['id']
  236. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  237. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  238. if not user_data:
  239. logger.warning(f"staff[{wxid}] has no user")
  240. continue
  241. user_ids = tuple(user['user_id'] for user in user_data)
  242. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
  243. if tag_id:
  244. sql += f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = {tag_id} and is_delete = 0)"
  245. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  246. if not user_data:
  247. logger.warning(f"staff[{wxid}] users not found in wecom database")
  248. continue
  249. user_union_ids = tuple(user['union_id'] for user in user_data)
  250. batch_size = 500
  251. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  252. agent_user_data = []
  253. for i in range(n_batches):
  254. idx_begin = i * batch_size
  255. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  256. batch_union_ids = user_union_ids[idx_begin:idx_end]
  257. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  258. batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  259. if len(agent_user_data) != len(batch_union_ids):
  260. # logger.debug(f"staff[{wxid}] some users not found in agent database")
  261. pass
  262. agent_user_data.extend(batch_agent_user_data)
  263. staff_user_pairs = [
  264. {
  265. 'staff_id': agent_staff['third_party_user_id'],
  266. 'user_id': agent_user['third_party_user_id']
  267. }
  268. for agent_user in agent_user_data
  269. ]
  270. ret.extend(staff_user_pairs)
  271. return ret
  272. def get_user_tags(self, user_id: str) -> List[str]:
  273. sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
  274. user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  275. if not user_data:
  276. logger.error(f"user[{user_id}] has no wxid")
  277. return []
  278. user_wxid = user_data[0]['wxid']
  279. sql = f"""
  280. select b.tag_id, c.`tag_name` from `we_com_user` as a
  281. join `we_com_user_with_tag` as b
  282. join `we_com_tag` as c
  283. on a.`id` = b.`user_id`
  284. and b.`tag_id` = c.id
  285. where a.union_id = '{user_wxid}' """
  286. tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  287. tag_names = [tag['tag_name'] for tag in tag_data]
  288. return tag_names
  289. if __name__ == '__main__':
  290. config = configs.get()
  291. user_db_config = config['storage']['user']
  292. staff_db_config = config['storage']['staff']
  293. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  294. user_profile = user_manager.get_user_profile('7881301263964433')
  295. print(user_profile)
  296. wecom_db_config = config['storage']['user_relation']
  297. user_relation_manager = MySQLUserRelationManager(
  298. user_db_config['mysql'], wecom_db_config['mysql'],
  299. config['storage']['staff']['table'],
  300. user_db_config['table'],
  301. wecom_db_config['table']['staff'],
  302. wecom_db_config['table']['relation'],
  303. wecom_db_config['table']['user']
  304. )
  305. # all_staff_users = user_relation_manager.list_staff_users()
  306. user_tags = user_relation_manager.get_user_tags('7881302078008656')
  307. print(user_tags)