#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
"""User profile storage backends: local JSON files and a MySQL table."""

import abc
import json
import logging
import os
import time
from typing import Any, Dict, Optional, Tuple

import pymysql.converters
import pymysql.cursors

import configs
from database import MySQLManager


class UserManager(abc.ABC):
    """Abstract interface for loading and persisting per-user profile dicts."""

    @abc.abstractmethod
    def get_user_profile(self, user_id) -> Dict:
        """Return the profile dict stored for ``user_id``."""

    @abc.abstractmethod
    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Persist ``profile`` as the profile of ``user_id``."""

    @abc.abstractmethod
    def list_all_users(self):
        """Return the ids of all users known to this backend."""

    @staticmethod
    def get_default_profile(**kwargs) -> Dict:
        """Build a fresh default profile, overriding fields from ``kwargs``.

        Keyword arguments whose name is not a default profile field are
        silently ignored. ``created_at`` is the current time in milliseconds
        since the epoch.
        """
        default_profile = {
            "name": "",
            "nickname": "",
            "preferred_nickname": "",
            "age": 0,
            "region": '',
            "interests": [],
            "family_members": {},
            "health_conditions": [],
            "medications": [],
            "reminder_preferences": {
                "medication": True,
                "health": True,
                "weather": True,
                "news": False
            },
            "interaction_style": "standard",  # standard, verbose, concise
            "interaction_frequency": "medium",  # low, medium, high
            "last_topics": [],
            "created_at": int(time.time() * 1000),
            "human_intervention_history": []
        }
        for key, value in kwargs.items():
            if key in default_profile:
                default_profile[key] = value
        return default_profile


class UserRelationManager(abc.ABC):
    """Abstract interface for listing staff and, per staff member, a page of users."""

    @abc.abstractmethod
    def list_staffs(self):
        """Return all staff members."""

    @abc.abstractmethod
    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """Return one page (``page``, ``page_size``) of users for ``staff_id``."""


class LocalUserManager(UserManager):
    """Store each profile as ``<profile_dir>/<user_id>.json``; mainly for local debugging."""

    # Directory (relative to the current working directory) holding profile files.
    profile_dir = "user_profiles"

    def _profile_path(self, user_id) -> str:
        """Return the JSON file path for ``user_id``."""
        return os.path.join(self.profile_dir, f"{user_id}.json")

    def get_user_profile(self, user_id) -> Dict:
        """Load the user's profile; if it does not exist, create and save a default one."""
        try:
            with open(self._profile_path(user_id), "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            # No profile on disk yet: create a default profile and persist it.
            default_profile = self.get_default_profile()
            self.save_user_profile(user_id, default_profile)
            return default_profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Write ``profile`` to the user's JSON file (UTF-8, indented).

        Raises:
            ValueError: if ``user_id`` is falsy.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        # Create the directory on first use; open(..., "w") cannot create it.
        os.makedirs(self.profile_dir, exist_ok=True)
        with open(self._profile_path(user_id), "w", encoding="utf-8") as f:
            json.dump(profile, f, ensure_ascii=False, indent=2)

    def list_all_users(self):
        """Return ids taken from every ``*.json`` file name under ``profile_dir`` (recursively)."""
        user_ids = []
        for _, _, file_names in os.walk(self.profile_dir):
            for file_name in file_names:
                if file_name.endswith('.json'):
                    user_ids.append(os.path.splitext(file_name)[0])
        return user_ids


class MySQLUserManager(UserManager):
    """Store profiles as JSON in the ``profile_data_v1`` column of a MySQL table.

    Rows are looked up by ``third_party_user_id``. This manager only issues
    UPDATEs and never INSERTs, so a user must already have a row.
    """

    def __init__(self, db_config, table_name):
        self.db = MySQLManager(db_config)
        # Interpolated directly into SQL, so it must come from trusted config.
        self.table_name = table_name

    @staticmethod
    def _sql_literal(value) -> str:
        """Render ``value`` as a SQL literal: ints verbatim, anything else as an escaped quoted string."""
        if isinstance(value, int) and not isinstance(value, bool):
            return str(value)
        return "'{}'".format(pymysql.converters.escape_string(str(value)))

    def get_user_profile(self, user_id) -> Dict:
        """Return the stored profile for ``user_id``.

        Returns an empty dict (and logs an error) if no row matches. If the
        row exists but has no profile yet, a default profile whose nickname
        is the row's ``name`` is saved and returned.
        """
        # Only the (sql, cursor_class) form of MySQLManager.select is known
        # here, so the id is escaped into the SQL instead of bound as a param.
        sql = (
            f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} "
            f"WHERE third_party_user_id = {self._sql_literal(user_id)}"
        )
        rows = self.db.select(sql, pymysql.cursors.DictCursor)
        if not rows:
            logging.error("user[%s] not found", user_id)
            return {}
        row = rows[0]
        if not row['profile_data_v1']:
            logging.warning("user[%s] profile not found, create a default one", user_id)
            default_profile = self.get_default_profile(nickname=row['name'])
            self.save_user_profile(user_id, default_profile)
            # The fetched column is still empty; return what was just saved.
            return default_profile
        return json.loads(row['profile_data_v1'])

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Serialize ``profile`` to JSON and store it in the user's row.

        Only updates an existing row; if no row matches ``user_id`` the
        UPDATE changes nothing.

        Raises:
            ValueError: if ``user_id`` is falsy.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = %s"
        # Both values are bound as parameters rather than formatted into the SQL.
        self.db.execute(sql, (json.dumps(profile), user_id))

    def list_all_users(self):
        """Return ``third_party_user_id`` of every row in the table."""
        sql = f"SELECT third_party_user_id FROM {self.table_name}"
        rows = self.db.select(sql, pymysql.cursors.DictCursor)
        # Treat a falsy result as "no users", matching get_user_profile's check.
        return [row['third_party_user_id'] for row in rows or []]


if __name__ == '__main__':
    db_config = configs.get()['storage']['user']
    user_manager = MySQLUserManager(db_config['mysql'], db_config['table'])
    user_profile = user_manager.get_user_profile('7881301263964433')
    print(user_profile)