user_manager.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import logging
  5. from typing import Dict, Optional, Tuple, Any, List
  6. import json
  7. import time
  8. import os
  9. import abc
  10. import pymysql.cursors
  11. import configs
  12. from database import MySQLManager
  13. class UserManager(abc.ABC):
  14. @abc.abstractmethod
  15. def get_user_profile(self, user_id) -> Dict:
  16. pass
  17. @abc.abstractmethod
  18. def save_user_profile(self, user_id, profile: Dict) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def list_all_users(self):
  22. pass
  23. @staticmethod
  24. def get_default_profile(**kwargs) -> Dict:
  25. default_profile = {
  26. "name": "",
  27. "nickname": "",
  28. "preferred_nickname": "",
  29. "age": 0,
  30. "region": '',
  31. "interests": [],
  32. "family_members": {},
  33. "health_conditions": [],
  34. "medications": [],
  35. "reminder_preferences": {
  36. "medication": True,
  37. "health": True,
  38. "weather": True,
  39. "news": False
  40. },
  41. "interaction_style": "standard", # standard, verbose, concise
  42. "interaction_frequency": "medium", # low, medium, high
  43. "last_topics": [],
  44. "created_at": int(time.time() * 1000),
  45. "human_intervention_history": []
  46. }
  47. for key, value in kwargs.items():
  48. if key in default_profile:
  49. default_profile[key] = value
  50. return default_profile
  51. class UserRelationManager(abc.ABC):
  52. @abc.abstractmethod
  53. def list_staffs(self):
  54. pass
  55. @abc.abstractmethod
  56. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  57. pass
  58. @abc.abstractmethod
  59. def list_staff_users(self) -> List[Dict]:
  60. pass
  61. class LocalUserManager(UserManager):
  62. def get_user_profile(self, user_id) -> Dict:
  63. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  64. try:
  65. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  66. return json.load(f)
  67. except FileNotFoundError:
  68. # 创建默认用户资料
  69. default_profile = self.get_default_profile()
  70. self.save_user_profile(user_id, default_profile)
  71. return default_profile
  72. def save_user_profile(self, user_id, profile: Dict) -> None:
  73. if not user_id:
  74. raise Exception("Invalid user_id: {}".format(user_id))
  75. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  76. json.dump(profile, f, ensure_ascii=False, indent=2)
  77. def list_all_users(self):
  78. user_ids = []
  79. for root, dirs, files in os.walk('user_profiles/'):
  80. for file in files:
  81. if file.endswith('.json'):
  82. user_ids.append(os.path.splitext(file)[0])
  83. return user_ids
  84. class MySQLUserManager(UserManager):
  85. def __init__(self, db_config, table_name):
  86. self.db = MySQLManager(db_config)
  87. self.table_name = table_name
  88. def get_user_profile(self, user_id) -> Dict:
  89. sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  90. data = self.db.select(sql, pymysql.cursors.DictCursor)
  91. if not data:
  92. logging.error(f"user[{user_id}] not found")
  93. return {}
  94. data = data[0]
  95. if not data['profile_data_v1']:
  96. logging.warning(f"user[{user_id}] profile not found, create a default one")
  97. default_profile = self.get_default_profile(nickname=data['name'])
  98. self.save_user_profile(user_id, default_profile)
  99. return json.loads(data['profile_data_v1'])
  100. def save_user_profile(self, user_id, profile: Dict) -> None:
  101. if not user_id:
  102. raise Exception("Invalid user_id: {}".format(user_id))
  103. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  104. self.db.execute(sql, (json.dumps(profile),))
  105. def list_all_users(self):
  106. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  107. data = self.db.select(sql, pymysql.cursors.DictCursor)
  108. return [user['third_party_user_id'] for user in data]
  109. class MySQLUserRelationManager(UserRelationManager):
  110. def __init__(self, agent_db_config, wecom_db_config,
  111. agent_staff_table, agent_user_table,
  112. staff_table, relation_table, user_table):
  113. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  114. self.agent_db = MySQLManager(agent_db_config)
  115. self.wecom_db = MySQLManager(wecom_db_config)
  116. self.agent_staff_table = agent_staff_table
  117. self.staff_table = staff_table
  118. self.relation_table = relation_table
  119. self.agent_user_table = agent_user_table
  120. self.user_table = user_table
  121. def list_staffs(self):
  122. return []
  123. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  124. return []
  125. def list_staff_users(self):
  126. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  127. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  128. if not agent_staff_data:
  129. return []
  130. ret = []
  131. for agent_staff in agent_staff_data:
  132. wxid = agent_staff['wxid']
  133. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  134. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  135. if not staff_data:
  136. logging.error(f"staff[{wxid}] not found in wecom database")
  137. continue
  138. staff_id = staff_data[0]['id']
  139. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  140. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  141. if not user_data:
  142. logging.warning(f"staff[{wxid}] has no user")
  143. continue
  144. user_ids = tuple(user['user_id'] for user in user_data)
  145. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)}"
  146. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  147. if not user_data:
  148. logging.error(f"staff[{wxid}] users not found in wecom database")
  149. continue
  150. user_union_ids = tuple(user['union_id'] for user in user_data)
  151. batch_size = 100
  152. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  153. agent_user_data = []
  154. for i in range(n_batches):
  155. idx_begin = i * batch_size
  156. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  157. batch_union_ids = user_union_ids[idx_begin:idx_end]
  158. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  159. batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  160. if len(agent_user_data) != len(batch_union_ids):
  161. logging.error(f"staff[{wxid}] some users not found in agent database")
  162. agent_user_data.extend(batch_agent_user_data)
  163. staff_user_pairs = [
  164. {
  165. 'staff_id': agent_staff['third_party_user_id'],
  166. 'user_id': agent_user['third_party_user_id']
  167. }
  168. for agent_user in agent_user_data
  169. ]
  170. ret.extend(staff_user_pairs)
  171. return ret
  172. if __name__ == '__main__':
  173. config = configs.get()
  174. user_db_config = config['storage']['user']
  175. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'])
  176. user_profile = user_manager.get_user_profile('7881301263964433')
  177. print(user_profile)
  178. wecom_db_config = config['storage']['user_relation']
  179. user_relation_manager = MySQLUserRelationManager(
  180. user_db_config['mysql'], wecom_db_config['mysql'],
  181. config['storage']['staff']['table'],
  182. user_db_config['table'],
  183. wecom_db_config['table']['staff'],
  184. wecom_db_config['table']['relation'],
  185. wecom_db_config['table']['user']
  186. )
  187. all_staff_users = user_relation_manager.list_staff_users()
  188. print(all_staff_users)