#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

from logging_service import logger
from typing import Dict, Optional, Tuple, Any, List
import json
import time
import os
import abc

import pymysql.converters
import pymysql.cursors

import configs
from database import MySQLManager

# Label used whenever a gender code is missing or not recognised.
_UNKNOWN_GENDER = '未知'
# Numeric gender codes stored in the database -> label stored in profiles.
_GENDER_LABELS = {0: _UNKNOWN_GENDER, 1: '男', 2: '女'}


def _gender_label(code: Any) -> str:
    """Return the profile label for a database gender code.

    Unknown codes (including ``None``) map to the "unknown" label instead of
    raising ``KeyError``.
    """
    return _GENDER_LABELS.get(code, _UNKNOWN_GENDER)


def _sql_literal(value: Any) -> str:
    """Render *value* as a SQL literal that is safe to embed in a query string.

    Integers are emitted bare, ``None`` becomes ``NULL`` and everything else is
    converted to ``str``, escaped and single-quoted.  This is only used for
    queries sent through ``MySQLManager.select``, which is called here without
    bound parameters.
    """
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, int):
        return str(value)
    return "'" + pymysql.converters.escape_string(str(value)) + "'"


def _sql_in_list(values) -> str:
    """Render *values* as a parenthesised SQL ``IN`` list.

    Unlike ``str(tuple(values))`` this never produces the trailing comma that
    Python adds to one-element tuples (``('x',)``), which is invalid SQL.
    """
    return "(" + ", ".join(_sql_literal(value) for value in values) + ")"


class UserManager(abc.ABC):
    """Interface for loading and storing user (and staff) profiles."""

    @abc.abstractmethod
    def get_user_profile(self, user_id) -> Dict:
        """Return the profile dict of *user_id*."""

    @abc.abstractmethod
    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Persist *profile* as the profile of *user_id*."""

    @abc.abstractmethod
    def list_all_users(self):
        """Return the ids of all known users."""

    @abc.abstractmethod
    def get_staff_profile(self, staff_id) -> Dict:
        """Return the profile dict of the staff member *staff_id*."""
        # FIXME(zhoutian): redesign the user / staff data management model

    @staticmethod
    def get_default_profile(**kwargs) -> Dict:
        """Build a fresh default profile.

        Keyword arguments override the matching default entries; keys that
        are not part of the default profile are ignored.
        """
        default_profile = {
            "name": "",
            "nickname": "",
            "preferred_nickname": "",
            "gender": "未知",
            "age": 0,
            "region": '',
            "interests": [],
            "family_members": {},
            "health_conditions": [],
            "medications": [],
            "reminder_preferences": {
                "medication": True,
                "health": True,
                "weather": True,
                "news": False
            },
            "interaction_style": "standard",  # standard, verbose, concise
            "interaction_frequency": "medium",  # low, medium, high
            "last_topics": [],
            "created_at": int(time.time() * 1000),
            "human_intervention_history": []
        }
        for key, value in kwargs.items():
            if key in default_profile:
                default_profile[key] = value
        return default_profile

    @staticmethod
    def _fill_missing_entries(user_id, profile: Dict, default_profile: Dict) -> bool:
        """Copy entries missing from *profile* out of *default_profile*.

        *profile* is updated in place.  Returns ``True`` when at least one
        entry was added, so the caller knows the profile must be saved again.
        """
        entry_added = False
        for key, value in default_profile.items():
            if key not in profile:
                logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
                profile[key] = value
                entry_added = True
        return entry_added

    def list_users(self, **kwargs) -> List[Dict]:
        """Search users by keyword filters; the base implementation finds none."""
        return []


class UserRelationManager(abc.ABC):
    """Interface for querying staff, their users and user tags."""

    @abc.abstractmethod
    def list_staffs(self):
        """Return the staff records."""

    @abc.abstractmethod
    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """Return one page of the users attached to *staff_id*."""

    @abc.abstractmethod
    def list_staff_users(self) -> List[Dict]:
        """Return ``{'staff_id': ..., 'user_id': ...}`` pairs."""

    @abc.abstractmethod
    def get_user_tags(self, user_id: str) -> List[str]:
        """Return the tag names attached to *user_id*."""


class LocalUserManager(UserManager):
    """Profile store backed by JSON files; mainly intended for local debugging."""

    # Directory (relative to the working directory) holding ``<user_id>.json``.
    PROFILE_DIR = "user_profiles"

    def _profile_path(self, user_id) -> str:
        """Return the path of the JSON file that stores *user_id*'s profile."""
        return os.path.join(self.PROFILE_DIR, f"{user_id}.json")

    def get_user_profile(self, user_id) -> Dict:
        """Load the user's profile, creating a default one if none exists.

        Entries that exist in the default profile but not in the stored one
        are added and the merged profile is written back.
        """
        default_profile = self.get_default_profile()
        try:
            with open(self._profile_path(user_id), "r", encoding="utf-8") as f:
                profile = json.load(f)
        except FileNotFoundError:
            # No stored profile yet: create the default one.
            self.save_user_profile(user_id, default_profile)
            return default_profile

        if self._fill_missing_entries(user_id, profile, default_profile):
            self.save_user_profile(user_id, profile)
        return profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Write *profile* to the user's JSON file.

        Raises:
            ValueError: if *user_id* is empty.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        # The profile directory may not exist on a fresh checkout.
        os.makedirs(self.PROFILE_DIR, exist_ok=True)
        with open(self._profile_path(user_id), "w", encoding="utf-8") as f:
            json.dump(profile, f, ensure_ascii=False, indent=2)

    def list_all_users(self):
        """Return the user ids of every ``*.json`` file under the profile directory."""
        return [
            os.path.splitext(file_name)[0]
            for _root, _dirs, files in os.walk(self.PROFILE_DIR)
            for file_name in files
            if file_name.endswith('.json')
        ]

    def get_staff_profile(self, staff_id) -> Dict:
        """Staff profiles are not stored locally; always returns an empty dict."""
        return {}

    def list_users(self, **kwargs) -> List[Dict]:
        """User search is not supported by the local store; returns no users."""
        return []


class MySQLUserManager(UserManager):
    """Profile store backed by a MySQL user table and a staff table."""

    def __init__(self, db_config, table_name, staff_table):
        self.db = MySQLManager(db_config)
        self.table_name = table_name
        self.staff_table = staff_table

    def get_user_profile(self, user_id) -> Dict:
        """Load the profile of *user_id* from the user table.

        Returns an empty dict when the user row does not exist.  When the row
        has no stored profile a default one (seeded with the row's name and
        gender) is saved and returned.  Entries that were added to the default
        profile since the profile was stored are merged in and persisted.
        """
        sql = f"SELECT name, wxid, profile_data_v1, gender" \
              f" FROM {self.table_name} WHERE third_party_user_id = {_sql_literal(user_id)}"
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        if not data:
            logger.error(f"user[{user_id}] not found")
            return {}

        row = data[0]
        default_profile = self.get_default_profile(
            nickname=row['name'], gender=_gender_label(row['gender']))

        if not row['profile_data_v1']:
            logger.warning(f"user[{user_id}] profile not found, create a default one")
            self.save_user_profile(user_id, default_profile)
            return default_profile

        profile = json.loads(row['profile_data_v1'])
        # New default entries must be merged into the stored profile.
        if self._fill_missing_entries(user_id, profile, default_profile):
            self.save_user_profile(user_id, profile)
        return profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Store *profile* (as JSON) in the row of *user_id*.

        Does nothing when the ``debug_flags.disable_database_write`` config
        flag is set.

        Raises:
            ValueError: if *user_id* is empty.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        if configs.get().get('debug_flags', {}).get('disable_database_write', False):
            return
        # Both values are bound by the driver rather than formatted into the SQL.
        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = %s"
        self.db.execute(sql, (json.dumps(profile), user_id))

    def list_all_users(self):
        """Return the ``third_party_user_id`` of every row in the user table."""
        sql = f"SELECT third_party_user_id FROM {self.table_name}"
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        return [user['third_party_user_id'] for user in data]

    def get_staff_profile(self, staff_id) -> Dict:
        """Return the agent profile row of *staff_id*, or ``{}`` if not found.

        The numeric ``agent_gender`` code is replaced by its label.

        Raises:
            ValueError: if no staff table was configured.
        """
        if not self.staff_table:
            raise ValueError("staff_table is not set")
        sql = f"SELECT agent_name, agent_gender, agent_age, agent_region " \
              f"FROM {self.staff_table} WHERE third_party_user_id = {_sql_literal(staff_id)}"
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        if not data:
            logger.error(f"staff[{staff_id}] not found")
            return {}
        profile = data[0]
        # Convert the stored gender code into its label.
        profile['agent_gender'] = _gender_label(profile['agent_gender'])
        return profile

    def list_users(self, **kwargs) -> List[Dict]:
        """Search users by ``user_union_id`` and/or ``user_name``.

        ``user_name`` is matched exactly (binary collation).  At least one of
        the two filters must be given.

        Raises:
            ValueError: if neither filter is provided.
        """
        user_union_id = kwargs.get('user_union_id', None)
        user_name = kwargs.get('user_name', None)
        if not user_union_id and not user_name:
            raise ValueError("user_union_id or user_name is required")

        sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
        if user_name:
            sql += f"AND name = {_sql_literal(user_name)} COLLATE utf8mb4_bin "
        if user_union_id:
            sql += f"AND wxid = {_sql_literal(user_union_id)} "
        return self.db.select(sql, pymysql.cursors.DictCursor)


class MySQLUserRelationManager(UserRelationManager):
    """Staff/user relations assembled from the agent and the WeCom databases."""

    # FIXME(zhoutian): test-period logic, only these staff accounts are served.
    _TEST_STAFF_IDS = ('1688854492669990', '1688855931724582')
    # FIXME(zhoutian): temporary test-period logic, this account gets all of
    # its users; every other account only gets users carrying tag 15.
    _UNFILTERED_STAFF_ID = '1688854492669990'
    # Maximum number of union ids sent in one ``IN`` query.
    _BATCH_SIZE = 100

    def __init__(self, agent_db_config, wecom_db_config,
                 agent_staff_table, agent_user_table,
                 staff_table, relation_table, user_table):
        # FIXME(zhoutian): the tables are not unified yet, so data has to be
        # read from two databases.
        self.agent_db = MySQLManager(agent_db_config)
        self.wecom_db = MySQLManager(wecom_db_config)
        self.agent_staff_table = agent_staff_table
        self.staff_table = staff_table
        self.relation_table = relation_table
        self.agent_user_table = agent_user_table
        self.user_table = user_table

    def list_staffs(self):
        """Return the rows of the agent staff table whose ``status`` is 1."""
        sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
        return self.agent_db.select(sql, pymysql.cursors.DictCursor)

    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """Not implemented yet; always returns an empty list."""
        return []

    def list_staff_users(self) -> List[Dict]:
        """Return ``{'staff_id', 'user_id'}`` pairs for the served staff accounts.

        Both ids are ``third_party_user_id`` values from the agent database.
        Staff or users that cannot be resolved are logged and skipped.
        """
        sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
               f" AND third_party_user_id in {_sql_in_list(self._TEST_STAFF_IDS)}")
        agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
        if not agent_staff_data:
            return []

        pairs = []
        for agent_staff in agent_staff_data:
            union_ids = self._list_staff_union_ids(agent_staff)
            pairs.extend(
                {
                    'staff_id': agent_staff['third_party_user_id'],
                    'user_id': agent_user['third_party_user_id']
                }
                for agent_user in self._list_agent_users(union_ids)
            )
        return pairs

    def _list_staff_union_ids(self, agent_staff: Dict) -> List:
        """Return the union ids of the WeCom users attached to *agent_staff*."""
        wxid = agent_staff['wxid']

        sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = {_sql_literal(wxid)}"
        staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not staff_data:
            logger.error(f"staff[{wxid}] not found in wecom database")
            return []
        staff_id = staff_data[0]['id']

        sql = (f"SELECT user_id FROM {self.relation_table}"
               f" WHERE staff_id = {_sql_literal(staff_id)} AND is_delete = 0")
        relation_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not relation_data:
            logger.warning(f"staff[{wxid}] has no user")
            return []
        user_ids = [relation['user_id'] for relation in relation_data]

        sql = (f"SELECT union_id FROM {self.user_table}"
               f" WHERE id IN {_sql_in_list(user_ids)} AND union_id is not null")
        if agent_staff['third_party_user_id'] != self._UNFILTERED_STAFF_ID:
            sql += " AND id in (SELECT distinct user_id FROM we_com_user_with_tag" \
                   " WHERE tag_id = 15 and is_delete = 0)"
        user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not user_data:
            logger.warning(f"staff[{wxid}] users not found in wecom database")
            return []
        return [user['union_id'] for user in user_data]

    def _list_agent_users(self, union_ids: List) -> List[Dict]:
        """Look up agent-database users by union id, in batches."""
        agent_users = []
        for begin in range(0, len(union_ids), self._BATCH_SIZE):
            batch = union_ids[begin:begin + self._BATCH_SIZE]
            sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_user_table}"
                   f" WHERE wxid IN {_sql_in_list(batch)}")
            # Union ids without a matching agent user are silently skipped.
            agent_users.extend(self.agent_db.select(sql, pymysql.cursors.DictCursor) or [])
        return agent_users

    def get_user_tags(self, user_id: str) -> List[str]:
        """Return the WeCom tag names of the agent user *user_id*.

        Returns an empty list when the user has no ``wxid`` in the agent
        database.
        """
        sql = (f"SELECT wxid FROM {self.agent_user_table}"
               f" WHERE third_party_user_id = {_sql_literal(user_id)} AND wxid is not null")
        user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
        if not user_data:
            logger.error(f"user[{user_id}] has no wxid")
            return []

        user_wxid = user_data[0]['wxid']
        sql = f"""
            select b.tag_id, c.`tag_name` from `we_com_user` as a
                join `we_com_user_with_tag` as b
                join `we_com_tag` as c
                on a.`id` = b.`user_id`
                and b.`tag_id` = c.id
            where a.union_id = {_sql_literal(user_wxid)}
        """
        tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        return [tag['tag_name'] for tag in tag_data]


def main() -> None:
    """Manual smoke test: print one user's profile and one user's tags."""
    config = configs.get()
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
    user_profile = user_manager.get_user_profile('7881301263964433')
    print(user_profile)

    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'],
        wecom_db_config['mysql'],
        config['storage']['staff']['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    # all_staff_users = user_relation_manager.list_staff_users()
    user_tags = user_relation_manager.get_user_tags('7881302078008656')
    print(user_tags)


if __name__ == '__main__':
    main()