# user_manager.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import abc
import json
import os
import time
from typing import Dict, Optional, List

import pymysql.cursors

from pqai_agent.logging_service import logger
from pqai_agent import configs
from pqai_agent.database import MySQLManager
  13. class UserManager(abc.ABC):
  14. @abc.abstractmethod
  15. def get_user_profile(self, user_id) -> Dict:
  16. pass
  17. @abc.abstractmethod
  18. def save_user_profile(self, user_id, profile: Dict) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def list_all_users(self):
  22. pass
  23. @abc.abstractmethod
  24. def get_staff_profile(self, staff_id) -> Dict:
  25. # FIXME(zhoutian): 重新设计用户和员工数据管理模型
  26. pass
  27. @staticmethod
  28. def get_default_profile(**kwargs) -> Dict:
  29. default_profile = {
  30. "name": "",
  31. "nickname": "",
  32. "avatar": "",
  33. "preferred_nickname": "",
  34. "gender": "未知",
  35. "age": 0,
  36. "region": "",
  37. "interests": [],
  38. "family_members": {},
  39. "health_conditions": [],
  40. "medications": [],
  41. "reminder_preferences": {
  42. "medication": True,
  43. "health": True,
  44. "weather": True,
  45. "news": False,
  46. },
  47. "interaction_style": "standard", # standard, verbose, concise
  48. "interaction_frequency": "medium", # low, medium, high
  49. "last_topics": [],
  50. "created_at": int(time.time() * 1000),
  51. "human_intervention_history": [],
  52. }
  53. for key, value in kwargs.items():
  54. if key in default_profile:
  55. default_profile[key] = value
  56. return default_profile
  57. def list_users(self, **kwargs) -> List[Dict]:
  58. pass
  59. class UserRelationManager(abc.ABC):
  60. @abc.abstractmethod
  61. def list_staffs(self):
  62. pass
  63. @abc.abstractmethod
  64. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  65. pass
  66. @abc.abstractmethod
  67. def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
  68. pass
  69. @abc.abstractmethod
  70. def get_user_tags(self, user_id: str) -> List[str]:
  71. pass
  72. @abc.abstractmethod
  73. def stop_user_daily_push(self, user_id: str) -> bool:
  74. pass
  75. class LocalUserManager(UserManager):
  76. def get_user_profile(self, user_id) -> Dict:
  77. """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
  78. default_profile = self.get_default_profile()
  79. try:
  80. with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
  81. profile = json.load(f)
  82. entry_added = False
  83. for key, value in default_profile.items():
  84. if key not in profile:
  85. logger.debug(
  86. f"user[{user_id}] add profile key[{key}] value[{value}]"
  87. )
  88. profile[key] = value
  89. entry_added = True
  90. if entry_added:
  91. self.save_user_profile(user_id, profile)
  92. return profile
  93. except FileNotFoundError:
  94. # 创建默认用户资料
  95. self.save_user_profile(user_id, default_profile)
  96. return default_profile
  97. def save_user_profile(self, user_id, profile: Dict) -> None:
  98. if not user_id:
  99. raise Exception("Invalid user_id: {}".format(user_id))
  100. with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
  101. json.dump(profile, f, ensure_ascii=False, indent=2)
  102. def list_all_users(self):
  103. user_ids = []
  104. for root, dirs, files in os.walk("../user_profiles/"):
  105. for file in files:
  106. if file.endswith(".json"):
  107. user_ids.append(os.path.splitext(file)[0])
  108. return user_ids
  109. def get_staff_profile(self, staff_id) -> Dict:
  110. try:
  111. with open(f"user_profiles/{staff_id}.json", "r", encoding="utf-8") as f:
  112. profile = json.load(f)
  113. entry_added = False
  114. if entry_added:
  115. self.save_user_profile(staff_id, profile)
  116. return profile
  117. except Exception as e:
  118. logger.error("staff profile not found: {}".format(e))
  119. return {}
  120. def list_users(self, **kwargs) -> List[Dict]:
  121. pass
  122. class MySQLUserManager(UserManager):
  123. PROFILE_EXCLUDE_ITEMS = [
  124. "avatar",
  125. ]
  126. def __init__(self, db_config, table_name, staff_table):
  127. self.db = MySQLManager(db_config)
  128. self.table_name = table_name
  129. self.staff_table = staff_table
  130. def get_user_profile(self, user_id) -> Dict:
  131. sql = (
  132. f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar"
  133. f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
  134. )
  135. data = self.db.select(sql, pymysql.cursors.DictCursor)
  136. if not data:
  137. logger.error(f"user[{user_id}] not found")
  138. return {}
  139. data = data[0]
  140. gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
  141. gender = gender_map[data["gender"]]
  142. default_profile = self.get_default_profile(
  143. nickname=data["name"], gender=gender, avatar=data["avatar"]
  144. )
  145. if not data["profile_data_v1"]:
  146. logger.warning(f"user[{user_id}] profile not found, create a default one")
  147. self.save_user_profile(user_id, default_profile)
  148. return default_profile
  149. else:
  150. profile = json.loads(data["profile_data_v1"])
  151. # 资料条目有增加时,需合并更新
  152. entry_added = False
  153. for key, value in default_profile.items():
  154. if key not in profile:
  155. # logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
  156. profile[key] = value
  157. entry_added = True
  158. if entry_added:
  159. self.save_user_profile(user_id, profile)
  160. return profile
  161. def save_user_profile(self, user_id, profile: Dict) -> None:
  162. if not user_id:
  163. raise Exception("Invalid user_id: {}".format(user_id))
  164. if configs.get().get("debug_flags", {}).get("disable_database_write", False):
  165. return
  166. profile = profile.copy()
  167. for name in self.PROFILE_EXCLUDE_ITEMS:
  168. profile.pop(name, None)
  169. sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
  170. self.db.execute(sql, (json.dumps(profile),))
  171. def list_all_users(self):
  172. sql = f"SELECT third_party_user_id FROM {self.table_name}"
  173. data = self.db.select(sql, pymysql.cursors.DictCursor)
  174. return [user["third_party_user_id"] for user in data]
  175. def get_staff_profile(self, staff_id) -> Dict:
  176. if not self.staff_table:
  177. raise Exception("staff_table is not set")
  178. return self.get_staff_profile_v3(staff_id)
  179. def get_staff_profile_v1(self, staff_id) -> Dict:
  180. sql = (
  181. f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile "
  182. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  183. )
  184. data = self.db.select(sql, pymysql.cursors.DictCursor)
  185. if not data:
  186. logger.error(f"staff[{staff_id}] not found")
  187. return {}
  188. profile = data[0]
  189. # 转换性别格式
  190. gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
  191. profile["agent_gender"] = gender_map[profile["agent_gender"]]
  192. return profile
  193. def get_staff_profile_v2(self, staff_id) -> Dict:
  194. sql = (
  195. f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile "
  196. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  197. )
  198. data = self.db.select(sql, pymysql.cursors.DictCursor)
  199. if not data:
  200. logger.error(f"staff[{staff_id}] not found")
  201. return {}
  202. profile = data[0]
  203. # 转换性别格式
  204. gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
  205. profile["gender"] = gender_map[profile["gender"]]
  206. # 合并JSON字段(新版本)数据
  207. if profile["agent_profile"]:
  208. detail_profile = json.loads(profile["agent_profile"])
  209. profile.update(detail_profile)
  210. # 去除原始字段
  211. profile.pop("agent_profile", None)
  212. return profile
  213. def get_staff_profile_v3(self, staff_id) -> Dict:
  214. sql = (
  215. f"SELECT agent_profile "
  216. f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
  217. )
  218. data = self.db.select(sql)
  219. if not data:
  220. logger.error(f"staff[{staff_id}] not found")
  221. return {}
  222. profile_str = data[0][0]
  223. if not profile_str:
  224. return {}
  225. profile = json.loads(profile_str)
  226. return profile
  227. def save_staff_profile(self, staff_id: str, profile: Dict):
  228. # 正常情况下不应该有此操作
  229. if not self.staff_table:
  230. raise Exception("staff_table is not set")
  231. if not staff_id:
  232. raise Exception("Invalid staff_id: {}".format(staff_id))
  233. sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
  234. self.db.execute(sql, (json.dumps(profile),))
  235. def list_users(self, **kwargs) -> List[Dict]:
  236. user_union_id = kwargs.get("user_union_id", None)
  237. user_name = kwargs.get("user_name", None)
  238. if not user_union_id and not user_name:
  239. raise Exception("user_union_id or user_name is required")
  240. sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
  241. if user_name:
  242. sql += f"AND name = '{user_name}' COLLATE utf8mb4_bin "
  243. if user_union_id:
  244. sql += f"AND wxid = '{user_union_id}' "
  245. data = self.db.select(sql, pymysql.cursors.DictCursor)
  246. return data
  247. def get_staff_sessions(
  248. self,
  249. staff_id,
  250. page_id: int = 1,
  251. page_size: int = 10,
  252. session_type: str = "default",
  253. ) -> List[Dict]:
  254. """
  255. :param page_size:
  256. :param page_id:
  257. :param session_type:
  258. :param staff_id:
  259. :return:
  260. """
  261. match session_type:
  262. case "active":
  263. sql = f"""
  264. select staff_id, current_state, user_id
  265. from agent_state
  266. where staff_id = %s and update_timestamp >= DATE_SUB(NOW(), INTERVAL 2 HOUR)
  267. order by update_timestamp desc;
  268. """
  269. case "human_intervention":
  270. sql = f"""
  271. select staff_id, current_state, user_id
  272. from agent_state
  273. where staff_id = %s and current_state = 5 order by update_timestamp desc;
  274. """
  275. case _:
  276. sql = f"""
  277. select t1.staff_id, t1.current_state, t1.user_id, t2.name, t2.iconurl
  278. from agent_state t1 join third_party_user t2
  279. on t1.user_id = t2.third_party_user_id
  280. where t1.staff_id = %s
  281. order by
  282. IF(t1.current_state = 5, 0, 1),
  283. t1.update_timestamp desc
  284. limit {page_size + 1} offset {page_size * (page_id - 1)};
  285. """
  286. staff_sessions = self.db.select(
  287. sql, cursor_type=pymysql.cursors.DictCursor, args=(staff_id,)
  288. )
  289. return staff_sessions
  290. def get_staff_sessions_summary_v1(
  291. self, staff_id, page_id: int, page_size: int, status: int
  292. ) -> Dict:
  293. """
  294. :param status: staff status(0: unemployed, 1: employed)
  295. :param staff_id: staff
  296. :param page_id: page id
  297. :param page_size: page size
  298. :return:
  299. :todo: 未使用 Mysql 连接池,每次查询均需要与 MySQL 建立连接,性能较低,需要优化
  300. """
  301. if not staff_id:
  302. get_staff_query = f"""
  303. select third_party_user_id, name from {self.staff_table} where status = %s
  304. limit %s offset %s;
  305. """
  306. staff_id_list = self.db.select(
  307. sql=get_staff_query,
  308. cursor_type=pymysql.cursors.DictCursor,
  309. args=(status, page_size + 1, (page_id - 1) * page_size),
  310. )
  311. if not staff_id_list:
  312. return {}
  313. if len(staff_id_list) > page_size:
  314. has_next_page = True
  315. next_page_id = page_id + 1
  316. staff_id_list = staff_id_list[:page_size]
  317. else:
  318. has_next_page = False
  319. next_page_id = None
  320. else:
  321. get_staff_query = f"""
  322. select third_party_user_id, name from {self.staff_table}
  323. where status = %s and third_party_user_id = %s;
  324. """
  325. staff_id_list = self.db.select(
  326. sql=get_staff_query,
  327. cursor_type=pymysql.cursors.DictCursor,
  328. args=(status, staff_id),
  329. )
  330. if not staff_id_list:
  331. return {}
  332. has_next_page = False
  333. next_page_id = None
  334. response_data = [
  335. {
  336. "staff_id": staff["third_party_user_id"],
  337. "staff_name": staff["name"],
  338. "active_sessions": len(
  339. self.get_staff_sessions(
  340. staff["third_party_user_id"], session_type="active"
  341. )
  342. ),
  343. "human_intervention_sessions": len(
  344. self.get_staff_sessions(
  345. staff["third_party_user_id"], session_type="human_intervention"
  346. )
  347. ),
  348. }
  349. for staff in staff_id_list
  350. ]
  351. return {
  352. "has_next_page": has_next_page,
  353. "next_page_id": next_page_id,
  354. "data": response_data,
  355. }
  356. def get_staff_session_list_v1(self, staff_id, page_id: int, page_size: int) -> Dict:
  357. """
  358. :param page_size:
  359. :param page_id:
  360. :param staff_id:
  361. :return:
  362. """
  363. session_list = self.get_staff_sessions(staff_id, page_id, page_size)
  364. if len(session_list) > page_size:
  365. has_next_page = True
  366. next_page_id = page_id + 1
  367. session_list = session_list[:page_size]
  368. else:
  369. has_next_page = False
  370. next_page_id = None
  371. response_data = []
  372. for session in session_list:
  373. temp_obj = {}
  374. user_id = session["user_id"]
  375. room_id = ":".join(["private", staff_id, user_id])
  376. select_query = f"""select content, max(sendtime) as max_timestamp from qywx_chat_history where roomid = %s;"""
  377. last_message = self.db.select(
  378. sql=select_query,
  379. cursor_type=pymysql.cursors.DictCursor,
  380. args=(room_id,),
  381. )
  382. if not last_message:
  383. temp_obj["message"] = ""
  384. temp_obj["timestamp"] = 0
  385. else:
  386. temp_obj["message"] = last_message[0]["content"]
  387. temp_obj["timestamp"] = last_message[0]["max_timestamp"]
  388. temp_obj["customer_id"] = user_id
  389. temp_obj["customer_name"] = session["name"]
  390. temp_obj["avatar"] = session["iconurl"]
  391. response_data.append(temp_obj)
  392. return {
  393. "staff_id": staff_id,
  394. "has_next_page": has_next_page,
  395. "next_page_id": next_page_id,
  396. "data": response_data,
  397. }
  398. def get_staff_list(self, page_id: int, page_size: int) -> Dict:
  399. """
  400. :param page_size:
  401. :param page_id:
  402. :return:
  403. """
  404. sql = f"""
  405. select t1.third_party_user_id as staff_id, t1.name as staff_name, t2.iconurl as avatar
  406. from qywx_employee t1 left join third_party_user t2
  407. on t1.third_party_user_id = t2.third_party_user_id
  408. limit %s offset %s;
  409. """
  410. staff_list = self.db.select(
  411. sql=sql,
  412. cursor_type=pymysql.cursors.DictCursor,
  413. args=(page_size + 1, page_size * (page_id - 1)),
  414. )
  415. if len(staff_list) > page_size:
  416. has_next_page = True
  417. next_page_id = page_id + 1
  418. staff_list = staff_list[:page_size]
  419. else:
  420. has_next_page = False
  421. next_page_id = None
  422. return {
  423. "has_next_page": has_next_page,
  424. "next_page": next_page_id,
  425. "data": staff_list,
  426. }
  427. def get_conversation_list_v1(
  428. self, staff_id: str, customer_id: str, page: Optional[int]
  429. ):
  430. """
  431. :param staff_id:
  432. :param customer_id:
  433. :param page: timestamp
  434. :return:
  435. """
  436. room_id = ":".join(["private", staff_id, customer_id])
  437. page_size = 20
  438. if not page:
  439. fetch_query = f"""
  440. select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
  441. from qywx_chat_history t1
  442. join third_party_user t2 on t1.sender = t2.third_party_user_id
  443. where roomid = %s
  444. order by sendtime desc
  445. limit %s;
  446. """
  447. messages = self.db.select(
  448. sql=fetch_query,
  449. cursor_type=pymysql.cursors.DictCursor,
  450. args=(room_id, page_size + 1),
  451. )
  452. else:
  453. fetch_query = f"""
  454. select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
  455. from qywx_chat_history t1
  456. join third_party_user t2 on t1.sender = t2.third_party_user_id
  457. where t1.roomid = %s and t1.sendtime <= %s
  458. order by sendtime desc
  459. limit %s;
  460. """
  461. messages = self.db.select(
  462. sql=fetch_query,
  463. cursor_type=pymysql.cursors.DictCursor,
  464. args=(room_id, page, page_size + 1),
  465. )
  466. if messages:
  467. if len(messages) > page_size:
  468. has_next_page = True
  469. next_page = messages[-1]["sendtime"]
  470. else:
  471. has_next_page = False
  472. next_page = None
  473. response_data = [
  474. {
  475. "sender_id": message["sender"],
  476. "sender_name": message["name"],
  477. "avatar": message["iconurl"],
  478. "content": message["content"],
  479. "timestamp": message["sendtime"],
  480. "role": "customer" if message["sender"] == customer_id else "staff",
  481. "message_type": message["msg_type"],
  482. }
  483. for message in messages[ :page_size]
  484. ]
  485. return {
  486. "staff_id": staff_id,
  487. "customer_id": customer_id,
  488. "has_next_page": has_next_page,
  489. "next_page": next_page,
  490. "data": response_data,
  491. }
  492. else:
  493. has_next_page = False
  494. next_page = None
  495. return {
  496. "staff_id": staff_id,
  497. "customer_id": customer_id,
  498. "has_next_page": has_next_page,
  499. "next_page": next_page,
  500. "data": [],
  501. }
  502. class LocalUserRelationManager(UserRelationManager):
  503. def __init__(self):
  504. pass
  505. def list_staffs(self):
  506. return [
  507. {
  508. "third_party_user_id": "1688855931724582",
  509. "name": "",
  510. "wxid": "ShengHuoLeQu",
  511. "agent_name": "小芳",
  512. }
  513. ]
  514. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  515. return []
  516. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  517. user_ids = [
  518. "7881299453089278",
  519. "7881299453132630",
  520. "7881299454186909",
  521. "7881299455103430",
  522. "7881299455173476",
  523. "7881299456216398",
  524. "7881299457990953",
  525. "7881299461167644",
  526. "7881299463002136",
  527. "7881299464081604",
  528. "7881299465121735",
  529. "7881299465998082",
  530. "7881299466221881",
  531. "7881299467152300",
  532. "7881299470051791",
  533. "7881299470112816",
  534. "7881299471149567",
  535. "7881299471168030",
  536. "7881299471277650",
  537. "7881299473321703",
  538. ]
  539. user_ids = user_ids[:5]
  540. return [
  541. {"staff_id": "1688855931724582", "user_id": "7881299670930896"},
  542. *[
  543. {"staff_id": "1688855931724582", "user_id": user_id}
  544. for user_id in user_ids
  545. ],
  546. ]
  547. def get_user_tags(self, user_id: str):
  548. return []
  549. def stop_user_daily_push(self, user_id: str) -> bool:
  550. return True
  551. class MySQLUserRelationManager(UserRelationManager):
  552. def __init__(
  553. self,
  554. agent_db_config,
  555. wecom_db_config,
  556. agent_staff_table,
  557. agent_user_table,
  558. staff_table,
  559. relation_table,
  560. user_table,
  561. ):
  562. # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
  563. self.agent_db = MySQLManager(agent_db_config)
  564. self.wecom_db = MySQLManager(wecom_db_config)
  565. self.agent_staff_table = agent_staff_table
  566. self.staff_table = staff_table
  567. self.relation_table = relation_table
  568. self.agent_user_table = agent_user_table
  569. self.user_table = user_table
  570. def list_staffs(self):
  571. sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
  572. data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  573. return data
  574. def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
  575. return []
  576. def list_staff_users(self, staff_id: str = None, tag_id: int = None):
  577. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
  578. if staff_id:
  579. sql += f" AND third_party_user_id = '{staff_id}'"
  580. agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  581. if not agent_staff_data:
  582. return []
  583. ret = []
  584. for agent_staff in agent_staff_data:
  585. wxid = agent_staff["wxid"]
  586. sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
  587. staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  588. if not staff_data:
  589. logger.error(f"staff[{wxid}] not found in wecom database")
  590. continue
  591. staff_id = staff_data[0]["id"]
  592. sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
  593. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  594. if not user_data:
  595. logger.warning(f"staff[{wxid}] has no user")
  596. continue
  597. user_ids = tuple(user["user_id"] for user in user_data)
  598. sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
  599. if tag_id:
  600. sql += f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = {tag_id} and is_delete = 0)"
  601. user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  602. if not user_data:
  603. logger.warning(f"staff[{wxid}] users not found in wecom database")
  604. continue
  605. user_union_ids = tuple(user["union_id"] for user in user_data)
  606. batch_size = 500
  607. n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
  608. agent_user_data = []
  609. for i in range(n_batches):
  610. idx_begin = i * batch_size
  611. idx_end = min((i + 1) * batch_size, len(user_union_ids))
  612. batch_union_ids = user_union_ids[idx_begin:idx_end]
  613. sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
  614. batch_agent_user_data = self.agent_db.select(
  615. sql, pymysql.cursors.DictCursor
  616. )
  617. if len(agent_user_data) != len(batch_union_ids):
  618. # logger.debug(f"staff[{wxid}] some users not found in agent database")
  619. pass
  620. agent_user_data.extend(batch_agent_user_data)
  621. staff_user_pairs = [
  622. {
  623. "staff_id": agent_staff["third_party_user_id"],
  624. "user_id": agent_user["third_party_user_id"],
  625. }
  626. for agent_user in agent_user_data
  627. ]
  628. ret.extend(staff_user_pairs)
  629. return ret
  630. def get_user_union_id(self, user_id: str) -> Optional[str]:
  631. sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
  632. user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
  633. if not user_data:
  634. logger.error(f"user[{user_id}] has no union id")
  635. return None
  636. union_id = user_data[0]["wxid"]
  637. return union_id
  638. def get_user_tags(self, user_id: str) -> List[str]:
  639. union_id = self.get_user_union_id(user_id)
  640. if not union_id:
  641. return []
  642. sql = f"""
  643. select b.tag_id, c.`tag_name` from `we_com_user` as a
  644. join `we_com_user_with_tag` as b
  645. join `we_com_tag` as c
  646. on a.`id` = b.`user_id`
  647. and b.`tag_id` = c.id
  648. where a.union_id = '{union_id}' """
  649. tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
  650. tag_names = [tag["tag_name"] for tag in tag_data]
  651. return tag_names
  652. def stop_user_daily_push(self, user_id: str) -> bool:
  653. try:
  654. union_id = self.get_user_union_id(user_id)
  655. if not union_id:
  656. return False
  657. sql = f"UPDATE {self.user_table} SET group_msg_disabled = 1 WHERE union_id = %s"
  658. rows = self.wecom_db.execute(sql, (union_id,))
  659. if rows > 0:
  660. return True
  661. else:
  662. return False
  663. except Exception as e:
  664. logger.error(f"stop_user_daily_push failed: {e}")
  665. return False
  666. if __name__ == "__main__":
  667. config = configs.get()
  668. user_db_config = config["storage"]["user"]
  669. staff_db_config = config["storage"]["staff"]
  670. user_manager = MySQLUserManager(
  671. user_db_config["mysql"], user_db_config["table"], staff_db_config["table"]
  672. )
  673. user_profile = user_manager.get_user_profile("7881301263964433")
  674. print(user_profile)
  675. wecom_db_config = config["storage"]["user_relation"]
  676. user_relation_manager = MySQLUserRelationManager(
  677. user_db_config["mysql"],
  678. wecom_db_config["mysql"],
  679. config["storage"]["staff"]["table"],
  680. user_db_config["table"],
  681. wecom_db_config["table"]["staff"],
  682. wecom_db_config["table"]["relation"],
  683. wecom_db_config["table"]["user"],
  684. )
  685. # all_staff_users = user_relation_manager.list_staff_users()
  686. user_tags = user_relation_manager.get_user_tags("7881302078008656")
  687. print(user_tags)