user_manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. from pqai_agent.logging_service import logger
  5. from typing import Dict, Optional, List
  6. import json
  7. import time
  8. import os
  9. import abc
  10. import pymysql.cursors
  11. from pqai_agent import configs
  12. from pqai_agent.database import MySQLManager
  13. class UserManager(abc.ABC):
  14. @abc.abstractmethod
  15. def get_user_profile(self, user_id) -> Dict:
  16. pass
  17. @abc.abstractmethod
  18. def save_user_profile(self, user_id, profile: Dict) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def list_all_users(self):
  22. pass
  23. @abc.abstractmethod
  24. def get_staff_profile(self, staff_id) -> Dict:
  25. #FIXME(zhoutian): 重新设计用户和员工数据管理模型
  26. pass
  27. @staticmethod
  28. def get_default_profile(**kwargs) -> Dict:
  29. default_profile = {
  30. "name": "",
  31. "nickname": "",
  32. "avatar": "",
  33. "preferred_nickname": "",
  34. "gender": "未知",
  35. "age": 0,
  36. "region": '',
  37. "interests": [],
  38. "family_members": {},
  39. "health_conditions": [],
  40. "medications": [],
  41. "reminder_preferences": {
  42. "medication": True,
  43. "health": True,
  44. "weather": True,
  45. "news": False
  46. },
  47. "interaction_style": "standard", # standard, verbose, concise
  48. "interaction_frequency": "medium", # low, medium, high
  49. "human_intervention_history": []
  50. }
  51. for key, value in kwargs.items():
  52. if key in default_profile:
  53. default_profile[key] = value
  54. return default_profile
  55. def list_users(self, **kwargs) -> List[Dict]:
  56. pass
  57. class UserRelationManager(abc.ABC):
  58. @abc.abstractmethod
  59. def list_staffs(self):
  60. pass
  61. @abc.abstractmethod
  62. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  63. pass
  64. @abc.abstractmethod
  65. def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
  66. pass
  67. @abc.abstractmethod
  68. def get_user_tags(self, user_id: str) -> List[str]:
  69. pass
  70. @abc.abstractmethod
  71. def stop_user_daily_push(self, user_id: str) -> bool:
  72. pass
  73. class LocalUserManager(UserManager):
  74. def get_user_profile(self, user_id) -> Dict:
  75. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  76. default_profile = self.get_default_profile()
  77. try:
  78. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  79. profile = json.load(f)
  80. entry_added = False
  81. for key, value in default_profile.items():
  82. if key not in profile:
  83. logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  84. profile[key] = value
  85. entry_added = True
  86. if entry_added:
  87. self.save_user_profile(user_id, profile)
  88. return profile
  89. except FileNotFoundError:
  90. # 创建默认用户资料
  91. self.save_user_profile(user_id, default_profile)
  92. return default_profile
  93. def save_user_profile(self, user_id, profile: Dict) -> None:
  94. if not user_id:
  95. raise Exception("Invalid user_id: {}".format(user_id))
  96. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  97. json.dump(profile, f, ensure_ascii=False, indent=2)
  98. def list_all_users(self):
  99. user_ids = []
  100. for root, dirs, files in os.walk('../user_profiles/'):
  101. for file in files:
  102. if file.endswith('.json'):
  103. user_ids.append(os.path.splitext(file)[0])
  104. return user_ids
  105. def get_staff_profile(self, staff_id) -> Dict:
  106. try:
  107. with open(f"user_profiles/{staff_id}.json", "r", encoding="utf-8") as f:
  108. profile = json.load(f)
  109. entry_added = False
  110. if entry_added:
  111. self.save_user_profile(staff_id, profile)
  112. return profile
  113. except Exception as e:
  114. logger.error("staff profile not found: {}".format(e))
  115. return {}
  116. def list_users(self, **kwargs) -> List[Dict]:
  117. pass
  118. class MySQLUserManager(UserManager):
  119. PROFILE_EXCLUDE_ITEMS = ['avatar', ]
  120. def __init__(self, db_config, table_name, staff_table):
  121. self.db = MySQLManager(db_config)
  122. self.table_name = table_name
  123. self.staff_table = staff_table
  124. def get_user_profile(self, user_id) -> Dict:
  125. sql = f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar" \
  126. f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  127. data = self.db.select(sql, pymysql.cursors.DictCursor)
  128. if not data:
  129. logger.error(f"user[{user_id}] not found")
  130. return {}
  131. data = data[0]
  132. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  133. gender = gender_map[data['gender']]
  134. default_profile = self.get_default_profile(nickname=data['name'], gender=gender, avatar=data['avatar'])
  135. if not data['profile_data_v1']:
  136. logger.warning(f"user[{user_id}] profile not found, create a default one")
  137. self.save_user_profile(user_id, default_profile)
  138. return default_profile
  139. else:
  140. profile = json.loads(data['profile_data_v1'])
  141. # 资料条目有增加时,需合并更新
  142. entry_added = False
  143. for key, value in default_profile.items():
  144. if key not in profile:
  145. # logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  146. profile[key] = value
  147. entry_added = True
  148. if entry_added:
  149. self.save_user_profile(user_id, profile)
  150. return profile
  151. def save_user_profile(self, user_id, profile: Dict) -> None:
  152. if not user_id:
  153. raise Exception("Invalid user_id: {}".format(user_id))
  154. if configs.get().get('debug_flags', {}).get('disable_database_write', False):
  155. return
  156. profile = profile.copy()
  157. for name in self.PROFILE_EXCLUDE_ITEMS:
  158. profile.pop(name, None)
  159. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  160. self.db.execute(sql, (json.dumps(profile),))
  161. def list_all_users(self):
  162. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  163. data = self.db.select(sql, pymysql.cursors.DictCursor)
  164. return [user['third_party_user_id'] for user in data]
  165. def get_staff_profile(self, staff_id) -> Dict:
  166. if not self.staff_table:
  167. raise Exception("staff_table is not set")
  168. return self.get_staff_profile_v3(staff_id)
  169. def get_staff_profile_v1(self, staff_id) -> Dict:
  170. sql = f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile " \
  171. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  172. data = self.db.select(sql, pymysql.cursors.DictCursor)
  173. if not data:
  174. logger.error(f"staff[{staff_id}] not found")
  175. return {}
  176. profile = data[0]
  177. # 转换性别格式
  178. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  179. profile['agent_gender'] = gender_map[profile['agent_gender']]
  180. return profile
  181. def get_staff_profile_v2(self, staff_id) -> Dict:
  182. sql = f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile " \
  183. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  184. data = self.db.select(sql, pymysql.cursors.DictCursor)
  185. if not data:
  186. logger.error(f"staff[{staff_id}] not found")
  187. return {}
  188. profile = data[0]
  189. # 转换性别格式
  190. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  191. profile['gender'] = gender_map[profile['gender']]
  192. # 合并JSON字段(新版本)数据
  193. if profile['agent_profile']:
  194. detail_profile = json.loads(profile['agent_profile'])
  195. profile.update(detail_profile)
  196. # 去除原始字段
  197. profile.pop('agent_profile', None)
  198. return profile
  199. def get_staff_profile_v3(self, staff_id) -> Dict:
  200. sql = f"SELECT agent_profile " \
  201. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  202. data = self.db.select(sql)
  203. if not data:
  204. logger.error(f"staff[{staff_id}] not found")
  205. return {}
  206. profile_str = data[0][0]
  207. if not profile_str:
  208. return {}
  209. profile = json.loads(profile_str)
  210. return profile
  211. def save_staff_profile(self, staff_id: str, profile: Dict):
  212. # 正常情况下不应该有此操作
  213. if not self.staff_table:
  214. raise Exception("staff_table is not set")
  215. if not staff_id:
  216. raise Exception("Invalid staff_id: {}".format(staff_id))
  217. sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
  218. self.db.execute(sql, (json.dumps(profile),))
  219. def list_users(self, **kwargs) -> List[Dict]:
  220. user_union_id = kwargs.get('user_union_id', None)
  221. user_name = kwargs.get('user_name', None)
  222. if not user_union_id and not user_name:
  223. raise Exception("user_union_id or user_name is required")
  224. sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
  225. if user_name:
  226. sql += f"AND name = '{user_name}' COLLATE utf8mb4_bin "
  227. if user_union_id:
  228. sql += f"AND wxid = '{user_union_id}' "
  229. data = self.db.select(sql, pymysql.cursors.DictCursor)
  230. return data
  231. def get_staff_list(self, page_id: int, page_size: int) -> Dict:
  232. """
  233. :param page_size:
  234. :param page_id:
  235. :return:
  236. """
  237. sql = f"""
  238. select t1.third_party_user_id as staff_id, t1.name as staff_name, t2.iconurl as avatar
  239. from qywx_employee t1 left join third_party_user t2
  240. on t1.third_party_user_id = t2.third_party_user_id
  241. limit %s offset %s;
  242. """
  243. staff_list = self.db.select(
  244. sql=sql,
  245. cursor_type=pymysql.cursors.DictCursor,
  246. args=(page_size + 1, page_size * (page_id - 1))
  247. )
  248. if len(staff_list) > page_size:
  249. has_next_page = True
  250. next_page_id = page_id + 1
  251. staff_list = staff_list[:page_size]
  252. else:
  253. has_next_page = False
  254. next_page_id = None
  255. return {
  256. "has_next_page": has_next_page,
  257. "next_page": next_page_id,
  258. "data": staff_list
  259. }
  260. class LocalUserRelationManager(UserRelationManager):
  261. def __init__(self):
  262. pass
  263. def list_staffs(self):
  264. return [
  265. {"third_party_user_id": '1688855931724582', "name": "", "wxid": "ShengHuoLeQu", "agent_name": "小芳"}
  266. ]
  267. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  268. return []
  269. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  270. user_ids = ['7881299453089278', '7881299453132630', '7881299454186909', '7881299455103430', '7881299455173476',
  271. '7881299456216398', '7881299457990953', '7881299461167644', '7881299463002136', '7881299464081604',
  272. '7881299465121735', '7881299465998082', '7881299466221881', '7881299467152300', '7881299470051791',
  273. '7881299470112816', '7881299471149567', '7881299471168030', '7881299471277650', '7881299473321703']
  274. user_ids = user_ids[:5]
  275. return [
  276. {"staff_id": "1688855931724582", "user_id": "7881299670930896"},
  277. *[{"staff_id": "1688855931724582", "user_id": user_id} for user_id in user_ids]
  278. ]
  279. def get_user_tags(self, user_id: str):
  280. return []
  281. def stop_user_daily_push(self, user_id: str) -> bool:
  282. return True
  283. class MySQLUserRelationManager(UserRelationManager):
  284. def __init__(self, agent_db_config, wecom_db_config,
  285. agent_staff_table, agent_user_table,
  286. staff_table, relation_table, user_table):
  287. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  288. self.agent_db = MySQLManager(agent_db_config)
  289. self.wecom_db = MySQLManager(wecom_db_config)
  290. self.agent_staff_table = agent_staff_table
  291. self.staff_table = staff_table
  292. self.relation_table = relation_table
  293. self.agent_user_table = agent_user_table
  294. self.user_table = user_table
  295. def list_staffs(self):
  296. sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
  297. data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  298. return data
  299. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  300. return []
  301. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  302. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  303. if staff_id:
  304. sql += f" AND third_party_user_id = '{staff_id}'"
  305. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  306. if not agent_staff_data:
  307. return []
  308. ret = []
  309. for agent_staff in agent_staff_data:
  310. wxid = agent_staff['wxid']
  311. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  312. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  313. if not staff_data:
  314. logger.error(f"staff[{wxid}] not found in wecom database")
  315. continue
  316. staff_id = staff_data[0]['id']
  317. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  318. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  319. if not user_data:
  320. logger.warning(f"staff[{wxid}] has no user")
  321. continue
  322. user_ids = tuple(user['user_id'] for user in user_data)
  323. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
  324. if tag_id:
  325. sql += f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = {tag_id} and is_delete = 0)"
  326. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  327. if not user_data:
  328. logger.warning(f"staff[{wxid}] users not found in wecom database")
  329. continue
  330. user_union_ids = tuple(user['union_id'] for user in user_data)
  331. batch_size = 500
  332. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  333. agent_user_data = []
  334. for i in range(n_batches):
  335. idx_begin = i * batch_size
  336. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  337. batch_union_ids = user_union_ids[idx_begin:idx_end]
  338. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  339. batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  340. if len(agent_user_data) != len(batch_union_ids):
  341. # logger.debug(f"staff[{wxid}] some users not found in agent database")
  342. pass
  343. agent_user_data.extend(batch_agent_user_data)
  344. staff_user_pairs = [
  345. {
  346. 'staff_id': agent_staff['third_party_user_id'],
  347. 'user_id': agent_user['third_party_user_id']
  348. }
  349. for agent_user in agent_user_data
  350. ]
  351. ret.extend(staff_user_pairs)
  352. return ret
  353. def get_user_union_id(self, user_id: str) -> Optional[str]:
  354. sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
  355. user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  356. if not user_data:
  357. logger.error(f"user[{user_id}] has no union id")
  358. return None
  359. union_id = user_data[0]['wxid']
  360. return union_id
  361. def get_user_tags(self, user_id: str) -> List[str]:
  362. union_id = self.get_user_union_id(user_id)
  363. if not union_id:
  364. return []
  365. sql = f"""
  366. select b.tag_id, c.`tag_name` from `we_com_user` as a
  367. join `we_com_user_with_tag` as b
  368. join `we_com_tag` as c
  369. on a.`id` = b.`user_id`
  370. and b.`tag_id` = c.id
  371. where a.union_id = '{union_id}' """
  372. tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  373. tag_names = [tag['tag_name'] for tag in tag_data]
  374. return tag_names
  375. def stop_user_daily_push(self, user_id: str) -> bool:
  376. try:
  377. union_id = self.get_user_union_id(user_id)
  378. if not union_id:
  379. return False
  380. sql = f"UPDATE {self.user_table} SET group_msg_disabled = 1 WHERE union_id = %s"
  381. rows = self.wecom_db.execute(sql, (union_id, ))
  382. if rows > 0:
  383. return True
  384. else:
  385. return False
  386. except Exception as e:
  387. logger.error(f"stop_user_daily_push failed: {e}")
  388. return False
  389. if __name__ == '__main__':
  390. config = configs.get()
  391. user_db_config = config['storage']['user']
  392. staff_db_config = config['storage']['staff']
  393. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  394. user_profile = user_manager.get_user_profile('7881301263964433')
  395. print(user_profile)
  396. wecom_db_config = config['storage']['user_relation']
  397. user_relation_manager = MySQLUserRelationManager(
  398. user_db_config['mysql'], wecom_db_config['mysql'],
  399. config['storage']['staff']['table'],
  400. user_db_config['table'],
  401. wecom_db_config['table']['staff'],
  402. wecom_db_config['table']['relation'],
  403. wecom_db_config['table']['user']
  404. )
  405. # all_staff_users = user_relation_manager.list_staff_users()
  406. user_tags = user_relation_manager.get_user_tags('7881302078008656')
  407. print(user_tags)