user_manager.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. from abc import abstractmethod
  5. from pqai_agent.logging_service import logger
  6. from typing import Dict, Optional, List
  7. import json
  8. import time
  9. import os
  10. import abc
  11. import pymysql.cursors
  12. from pqai_agent import configs
  13. from pqai_agent.database import MySQLManager
  14. class UserManager(abc.ABC):
  15. @abc.abstractmethod
  16. def get_user_profile(self, user_id) -> Dict:
  17. pass
  18. @abc.abstractmethod
  19. def save_user_profile(self, user_id, profile: Dict) -> None:
  20. pass
  21. @abc.abstractmethod
  22. def list_all_users(self):
  23. pass
  24. @abc.abstractmethod
  25. def get_staff_profile(self, staff_id) -> Dict:
  26. #FIXME(zhoutian): 重新设计用户和员工数据管理模型
  27. pass
  28. @abstractmethod
  29. def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
  30. pass
  31. @staticmethod
  32. def get_default_profile(**kwargs) -> Dict:
  33. default_profile = {
  34. "name": "",
  35. "nickname": "",
  36. "avatar": "",
  37. "preferred_nickname": "",
  38. "gender": "未知",
  39. "age": 0,
  40. "region": '',
  41. "interests": [],
  42. "family_members": {},
  43. "health_conditions": [],
  44. "medications": [],
  45. "reminder_preferences": {
  46. "medication": True,
  47. "health": True,
  48. "weather": True,
  49. "news": False
  50. },
  51. "interaction_style": "standard", # standard, verbose, concise
  52. "interaction_frequency": "medium", # low, medium, high
  53. "human_intervention_history": []
  54. }
  55. for key, value in kwargs.items():
  56. if key in default_profile:
  57. default_profile[key] = value
  58. return default_profile
  59. def list_users(self, **kwargs) -> List[Dict]:
  60. pass
  61. class UserRelationManager(abc.ABC):
  62. @abc.abstractmethod
  63. def list_staffs(self):
  64. pass
  65. @abc.abstractmethod
  66. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  67. pass
  68. @abc.abstractmethod
  69. def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
  70. pass
  71. @abc.abstractmethod
  72. def get_user_tags(self, user_id: str) -> List[str]:
  73. pass
  74. @abc.abstractmethod
  75. def stop_user_daily_push(self, user_id: str) -> bool:
  76. pass
  77. class LocalUserManager(UserManager):
  78. def get_user_profile(self, user_id) -> Dict:
  79. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  80. default_profile = self.get_default_profile()
  81. try:
  82. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  83. profile = json.load(f)
  84. entry_added = False
  85. for key, value in default_profile.items():
  86. if key not in profile:
  87. logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  88. profile[key] = value
  89. entry_added = True
  90. if entry_added:
  91. self.save_user_profile(user_id, profile)
  92. return profile
  93. except FileNotFoundError:
  94. # 创建默认用户资料
  95. self.save_user_profile(user_id, default_profile)
  96. return default_profile
  97. def save_user_profile(self, user_id, profile: Dict) -> None:
  98. if not user_id:
  99. raise Exception("Invalid user_id: {}".format(user_id))
  100. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  101. json.dump(profile, f, ensure_ascii=False, indent=2)
  102. def list_all_users(self):
  103. user_ids = []
  104. for root, dirs, files in os.walk('../user_profiles/'):
  105. for file in files:
  106. if file.endswith('.json'):
  107. user_ids.append(os.path.splitext(file)[0])
  108. return user_ids
  109. def get_staff_profile(self, staff_id) -> Dict:
  110. try:
  111. with open(f"user_profiles/{staff_id}.json", "r", encoding="utf-8") as f:
  112. profile = json.load(f)
  113. entry_added = False
  114. if entry_added:
  115. self.save_user_profile(staff_id, profile)
  116. return profile
  117. except Exception as e:
  118. logger.error("staff profile not found: {}".format(e))
  119. return {}
  120. def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
  121. return {}
  122. def list_users(self, **kwargs) -> List[Dict]:
  123. pass
  124. class MySQLUserManager(UserManager):
  125. PROFILE_EXCLUDE_ITEMS = ['avatar', ]
  126. def __init__(self, db_config, table_name, staff_table):
  127. self.db = MySQLManager(db_config)
  128. self.table_name = table_name
  129. self.staff_table = staff_table
  130. def get_user_profile(self, user_id) -> Dict:
  131. sql = f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar" \
  132. f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  133. data = self.db.select(sql, pymysql.cursors.DictCursor)
  134. if not data:
  135. logger.error(f"user[{user_id}] not found")
  136. return {}
  137. data = data[0]
  138. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  139. gender = gender_map[data['gender']]
  140. default_profile = self.get_default_profile(nickname=data['name'], gender=gender, avatar=data['avatar'])
  141. if not data['profile_data_v1']:
  142. logger.warning(f"user[{user_id}] profile not found, create a default one")
  143. self.save_user_profile(user_id, default_profile)
  144. return default_profile
  145. else:
  146. profile = json.loads(data['profile_data_v1'])
  147. # 资料条目有增加时,需合并更新
  148. entry_added = False
  149. for key, value in default_profile.items():
  150. if key not in profile:
  151. # logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  152. profile[key] = value
  153. entry_added = True
  154. if entry_added:
  155. self.save_user_profile(user_id, profile)
  156. return profile
  157. def save_user_profile(self, user_id, profile: Dict) -> None:
  158. if not user_id:
  159. raise Exception("Invalid user_id: {}".format(user_id))
  160. if configs.get().get('debug_flags', {}).get('disable_database_write', False):
  161. return
  162. profile = profile.copy()
  163. for name in self.PROFILE_EXCLUDE_ITEMS:
  164. profile.pop(name, None)
  165. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  166. self.db.execute(sql, (json.dumps(profile),))
  167. def list_all_users(self):
  168. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  169. data = self.db.select(sql, pymysql.cursors.DictCursor)
  170. return [user['third_party_user_id'] for user in data]
  171. def get_staff_profile(self, staff_id) -> Dict:
  172. if not self.staff_table:
  173. raise Exception("staff_table is not set")
  174. return self.get_staff_profile_v3(staff_id)
  175. def get_staff_profile_v1(self, staff_id) -> Dict:
  176. sql = f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile " \
  177. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  178. data = self.db.select(sql, pymysql.cursors.DictCursor)
  179. if not data:
  180. logger.error(f"staff[{staff_id}] not found")
  181. return {}
  182. profile = data[0]
  183. # 转换性别格式
  184. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  185. profile['agent_gender'] = gender_map[profile['agent_gender']]
  186. return profile
  187. def get_staff_profile_v2(self, staff_id) -> Dict:
  188. sql = f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile " \
  189. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  190. data = self.db.select(sql, pymysql.cursors.DictCursor)
  191. if not data:
  192. logger.error(f"staff[{staff_id}] not found")
  193. return {}
  194. profile = data[0]
  195. # 转换性别格式
  196. gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
  197. profile['gender'] = gender_map[profile['gender']]
  198. # 合并JSON字段(新版本)数据
  199. if profile['agent_profile']:
  200. detail_profile = json.loads(profile['agent_profile'])
  201. profile.update(detail_profile)
  202. # 去除原始字段
  203. profile.pop('agent_profile', None)
  204. return profile
  205. def get_staff_profile_v3(self, staff_id) -> Dict:
  206. sql = f"SELECT agent_profile_v2 " \
  207. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  208. data = self.db.select(sql)
  209. if not data:
  210. logger.error(f"staff[{staff_id}] not found")
  211. return {}
  212. profile_str = data[0][0]
  213. if not profile_str:
  214. return {}
  215. profile = json.loads(profile_str)
  216. return profile
  217. def save_staff_profile(self, staff_id: str, profile: Dict):
  218. # 正常情况下不应该有此操作
  219. if not self.staff_table:
  220. raise Exception("staff_table is not set")
  221. if not staff_id:
  222. raise Exception("Invalid staff_id: {}".format(staff_id))
  223. sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
  224. self.db.execute(sql, (json.dumps(profile),))
  225. def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
  226. """
  227. 获取用户的标签列表
  228. :param user_ids: 用户ID
  229. :param batch_size: 批量查询的大小
  230. :return: 标签名称列表
  231. """
  232. batches = (len(user_ids) + batch_size - 1) // batch_size
  233. ret = {}
  234. for i in range(batches):
  235. idx_begin = i * batch_size
  236. idx_end = min((i + 1) * batch_size, len(user_ids))
  237. batch_user_ids = user_ids[idx_begin:idx_end]
  238. sql = f"""
  239. SELECT a.third_party_user_id, b.tag_id, b.name FROM qywx_user_tag a
  240. JOIN qywx_tag b ON a.tag_id = b.tag_id
  241. AND a.third_party_user_id IN {str(tuple(batch_user_ids))}
  242. """
  243. rows = self.db.select(sql, pymysql.cursors.DictCursor)
  244. # group by user_id
  245. for row in rows:
  246. user_id = row['third_party_user_id']
  247. tag_name = row['name']
  248. if user_id not in ret:
  249. ret[user_id] = []
  250. ret[user_id].append(tag_name)
  251. return ret
  252. def list_users(self, **kwargs) -> List[Dict]:
  253. user_union_id = kwargs.get('user_union_id', None)
  254. user_name = kwargs.get('user_name', None)
  255. if not user_union_id and not user_name:
  256. raise Exception("user_union_id or user_name is required")
  257. sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
  258. if user_name:
  259. sql += f"AND name = '{user_name}' COLLATE utf8mb4_bin "
  260. if user_union_id:
  261. sql += f"AND wxid = '{user_union_id}' "
  262. data = self.db.select(sql, pymysql.cursors.DictCursor)
  263. return data
  264. def get_staff_list(self, page_id: int, page_size: int) -> Dict:
  265. """
  266. :param page_size:
  267. :param page_id:
  268. :return:
  269. """
  270. sql = f"""
  271. select t1.third_party_user_id as staff_id, t1.name as staff_name, t2.iconurl as avatar
  272. from qywx_employee t1 left join third_party_user t2
  273. on t1.third_party_user_id = t2.third_party_user_id
  274. limit %s offset %s;
  275. """
  276. staff_list = self.db.select(
  277. sql=sql,
  278. cursor_type=pymysql.cursors.DictCursor,
  279. args=(page_size + 1, page_size * (page_id - 1))
  280. )
  281. if len(staff_list) > page_size:
  282. has_next_page = True
  283. next_page_id = page_id + 1
  284. staff_list = staff_list[:page_size]
  285. else:
  286. has_next_page = False
  287. next_page_id = None
  288. return {
  289. "has_next_page": has_next_page,
  290. "next_page": next_page_id,
  291. "data": staff_list
  292. }
  293. class LocalUserRelationManager(UserRelationManager):
  294. def __init__(self):
  295. pass
  296. def list_staffs(self):
  297. return [
  298. {"third_party_user_id": '1688855931724582', "name": "", "wxid": "ShengHuoLeQu", "agent_name": "小芳"}
  299. ]
  300. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  301. return []
  302. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  303. user_ids = ['7881299453089278', '7881299453132630', '7881299454186909', '7881299455103430', '7881299455173476',
  304. '7881299456216398', '7881299457990953', '7881299461167644', '7881299463002136', '7881299464081604',
  305. '7881299465121735', '7881299465998082', '7881299466221881', '7881299467152300', '7881299470051791',
  306. '7881299470112816', '7881299471149567', '7881299471168030', '7881299471277650', '7881299473321703']
  307. user_ids = user_ids[:5]
  308. return [
  309. {"staff_id": "1688855931724582", "user_id": "7881299670930896"},
  310. *[{"staff_id": "1688855931724582", "user_id": user_id} for user_id in user_ids]
  311. ]
  312. def get_user_tags(self, user_id: str):
  313. return []
  314. def stop_user_daily_push(self, user_id: str) -> bool:
  315. return True
  316. class MySQLUserRelationManager(UserRelationManager):
  317. def __init__(self, agent_db_config, wecom_db_config,
  318. agent_staff_table, agent_user_table,
  319. staff_table, relation_table, user_table):
  320. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  321. self.agent_db = MySQLManager(agent_db_config)
  322. self.wecom_db = MySQLManager(wecom_db_config)
  323. self.agent_staff_table = agent_staff_table
  324. self.staff_table = staff_table
  325. self.relation_table = relation_table
  326. self.agent_user_table = agent_user_table
  327. self.user_table = user_table
  328. self.agent_user_relation_table = 'qywx_employee_customer'
  329. def list_staffs(self):
  330. sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
  331. data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  332. return data
  333. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  334. return []
  335. def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
  336. sql = f"SELECT employee_id as staff_id, customer_id as user_id FROM {self.agent_user_relation_table} WHERE 1 = 1"
  337. if staff_id:
  338. sql += f" AND employee_id = '{staff_id}'"
  339. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  340. return agent_staff_data
  341. def list_staff_users_v1(self, staff_id: str = None, tag_id: int = None):
  342. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  343. if staff_id:
  344. sql += f" AND third_party_user_id = '{staff_id}'"
  345. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  346. if not agent_staff_data:
  347. return []
  348. ret = []
  349. for agent_staff in agent_staff_data:
  350. wxid = agent_staff['wxid']
  351. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  352. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  353. if not staff_data:
  354. logger.error(f"staff[{wxid}] not found in wecom database")
  355. continue
  356. staff_id = staff_data[0]['id']
  357. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  358. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  359. if not user_data:
  360. logger.warning(f"staff[{wxid}] has no user")
  361. continue
  362. user_ids = tuple(user['user_id'] for user in user_data)
  363. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
  364. if tag_id:
  365. sql += f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = {tag_id} and is_delete = 0)"
  366. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  367. if not user_data:
  368. logger.warning(f"staff[{wxid}] users not found in wecom database")
  369. continue
  370. user_union_ids = tuple(user['union_id'] for user in user_data)
  371. batch_size = 500
  372. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  373. agent_user_data = []
  374. for i in range(n_batches):
  375. idx_begin = i * batch_size
  376. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  377. batch_union_ids = user_union_ids[idx_begin:idx_end]
  378. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  379. batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  380. if len(agent_user_data) != len(batch_union_ids):
  381. diff_num = len(batch_union_ids) - len(batch_agent_user_data)
  382. logger.debug(f"staff[{staff_id}] {diff_num} users not found in agent database")
  383. pass
  384. agent_user_data.extend(batch_agent_user_data)
  385. staff_user_pairs = [
  386. {
  387. 'staff_id': agent_staff['third_party_user_id'],
  388. 'user_id': agent_user['third_party_user_id']
  389. }
  390. for agent_user in agent_user_data
  391. ]
  392. ret.extend(staff_user_pairs)
  393. return ret
  394. def get_user_union_id(self, user_id: str) -> Optional[str]:
  395. sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
  396. user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  397. if not user_data:
  398. logger.error(f"user[{user_id}] has no union id")
  399. return None
  400. union_id = user_data[0]['wxid']
  401. return union_id
  402. def get_user_tags(self, user_id: str) -> List[str]:
  403. union_id = self.get_user_union_id(user_id)
  404. if not union_id:
  405. return []
  406. sql = f"""
  407. select b.tag_id, c.`tag_name` from `we_com_user` as a
  408. join `we_com_user_with_tag` as b
  409. join `we_com_tag` as c
  410. on a.`id` = b.`user_id`
  411. and b.`tag_id` = c.id
  412. where a.union_id = '{union_id}' """
  413. tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  414. tag_names = [tag['tag_name'] for tag in tag_data]
  415. return tag_names
  416. def stop_user_daily_push(self, user_id: str) -> bool:
  417. try:
  418. union_id = self.get_user_union_id(user_id)
  419. if not union_id:
  420. return False
  421. sql = f"UPDATE {self.user_table} SET group_msg_disabled = 1 WHERE union_id = %s"
  422. rows = self.wecom_db.execute(sql, (union_id, ))
  423. if rows > 0:
  424. return True
  425. else:
  426. return False
  427. except Exception as e:
  428. logger.error(f"stop_user_daily_push failed: {e}")
  429. return False