123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- import logging
- from typing import Dict, Optional, Tuple, Any
- import json
- import time
- import os
- import abc
- import pymysql.cursors
- import configs
- from database import MySQLManager
class UserManager(abc.ABC):
    """Abstract interface for reading, writing and enumerating user profiles.

    Concrete subclasses choose the storage backend; this base class only
    fixes the method names and supplies the shape of a brand-new profile.
    """

    @abc.abstractmethod
    def get_user_profile(self, user_id) -> Dict:
        """Return the profile dictionary belonging to ``user_id``."""

    @abc.abstractmethod
    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Store ``profile`` as the profile of ``user_id``."""

    @abc.abstractmethod
    def list_all_users(self):
        """Return the identifiers of every user the backend knows about."""

    @staticmethod
    def get_default_profile(**kwargs) -> Dict:
        """Build a fresh default profile, optionally overriding known fields.

        Args:
            **kwargs: Field overrides. Only keys that already exist in the
                default profile are applied; any other key is ignored.

        Returns:
            A newly created dictionary (nothing is shared between calls)
            whose ``created_at`` is the current time in milliseconds.
        """
        created_at_ms = int(time.time() * 1000)
        reminder_preferences = {
            "medication": True,
            "health": True,
            "weather": True,
            "news": False
        }
        profile = {
            "name": "",
            "nickname": "",
            "preferred_nickname": "",
            "age": 0,
            "region": '',
            "interests": [],
            "family_members": {},
            "health_conditions": [],
            "medications": [],
            "reminder_preferences": reminder_preferences,
            "interaction_style": "standard",  # standard, verbose, concise
            "interaction_frequency": "medium",  # low, medium, high
            "last_topics": [],
            "created_at": created_at_ms,
            "human_intervention_history": []
        }
        # Apply only the overrides that name an existing field.
        overrides = {
            field: value for field, value in kwargs.items() if field in profile
        }
        profile.update(overrides)
        return profile
class UserRelationManager(abc.ABC):
    """Abstract interface exposing staff listings and per-staff user listings."""

    @abc.abstractmethod
    def list_staffs(self):
        """List staff members; the return shape is defined by the subclass."""

    @abc.abstractmethod
    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """List users for ``staff_id``, one page at a time.

        Args:
            staff_id: Identifier of the staff member to query.
            page: Page to fetch (defaults to 1).
            page_size: Number of entries per page (defaults to 100).
        """
class LocalUserManager(UserManager):
    """File-backed user manager that keeps one JSON document per user.

    Profiles are read from and written to ``user_profiles/<user_id>.json``,
    resolved against the current working directory. Mainly intended for
    local debugging.
    """

    def get_user_profile(self, user_id) -> Dict:
        """Load a user's profile, creating a default one if none exists.

        Args:
            user_id: Identifier used as the profile file's base name.

        Returns:
            The stored profile, or a freshly created default profile that
            has also been written to disk.

        Raises:
            ValueError: If no profile exists yet and ``user_id`` is falsy,
                so a default profile cannot be saved for it.
        """
        try:
            with open(f"user_profiles/{user_id}.json", "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            # No profile yet: handled below, outside the handler, so that a
            # failure while saving is not reported as a chained exception.
            pass

        default_profile = self.get_default_profile()
        self.save_user_profile(user_id, default_profile)
        return default_profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Write ``profile`` to the user's JSON file.

        Args:
            user_id: Identifier used as the profile file's base name.
            profile: JSON-serialisable profile dictionary.

        Raises:
            ValueError: If ``user_id`` is falsy.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        # The storage directory may not exist on first use; without this the
        # open() below fails with FileNotFoundError.
        os.makedirs("user_profiles", exist_ok=True)
        with open(f"user_profiles/{user_id}.json", "w", encoding="utf-8") as f:
            json.dump(profile, f, ensure_ascii=False, indent=2)

    def list_all_users(self):
        """Return the ids of all stored users.

        The storage directory is walked recursively and every ``*.json``
        file contributes its base name. A missing directory yields an
        empty list.
        """
        user_ids = []
        for _root, _dirs, files in os.walk('user_profiles/'):
            for file_name in files:
                if file_name.endswith('.json'):
                    user_ids.append(os.path.splitext(file_name)[0])
        return user_ids
class MySQLUserManager(UserManager):
    """User manager that stores profiles as JSON text in a MySQL table.

    Rows are looked up by ``third_party_user_id`` and the profile itself
    lives in the ``profile_data_v1`` column.
    """

    def __init__(self, db_config, table_name):
        """Create the manager.

        Args:
            db_config: Connection settings handed to ``MySQLManager``.
            table_name: Name of the user table. It is interpolated into SQL
                verbatim, so it must come from trusted configuration.
        """
        self.db = MySQLManager(db_config)
        self.table_name = table_name

    @staticmethod
    def _normalize_user_id(user_id) -> int:
        """Return ``user_id`` as an int, rejecting anything non-numeric.

        The original queries embedded the id unquoted, so only numeric ids
        ever worked. Converting to ``int`` keeps that behaviour while making
        sure no arbitrary text can reach the SQL statement.

        Raises:
            ValueError: If ``user_id`` is not an int or a digit-only string.
        """
        text = str(user_id)
        if (
            isinstance(user_id, bool)
            or not isinstance(user_id, (int, str))
            or not text.isdigit()
        ):
            raise ValueError("Invalid user_id: {}".format(user_id))
        return int(text)

    def get_user_profile(self, user_id) -> Dict:
        """Fetch a user's profile, creating a default one if it is empty.

        Args:
            user_id: Numeric third-party user id (int or digit-only string).

        Returns:
            ``{}`` when no row matches; the decoded profile when one is
            stored; otherwise a default profile (nickname taken from the
            row's ``name``) that has also been saved.

        Raises:
            ValueError: If ``user_id`` is not numeric.
        """
        safe_id = self._normalize_user_id(user_id)
        # NOTE(review): MySQLManager.select's signature is not visible here,
        # so the id is validated as an int and interpolated rather than
        # bound as a parameter.
        sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {safe_id}"
        rows = self.db.select(sql, pymysql.cursors.DictCursor)
        if not rows:
            logging.error(f"user[{user_id}] not found")
            return {}
        row = rows[0]
        if not row['profile_data_v1']:
            logging.warning(f"user[{user_id}] profile not found, create a default one")
            default_profile = self.get_default_profile(nickname=row['name'])
            self.save_user_profile(user_id, default_profile)
            # Return the new profile directly: the fetched column is empty,
            # so decoding it would raise.
            return default_profile
        return json.loads(row['profile_data_v1'])

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Store ``profile`` as JSON in the user's row.

        Args:
            user_id: Numeric third-party user id (int or digit-only string).
            profile: JSON-serialisable profile dictionary.

        Raises:
            ValueError: If ``user_id`` is falsy or not numeric.
        """
        if not user_id:
            raise ValueError("Invalid user_id: {}".format(user_id))
        safe_id = self._normalize_user_id(user_id)
        # Both values are bound as parameters; only the table name is
        # interpolated.
        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = %s"
        self.db.execute(sql, (json.dumps(profile), safe_id))

    def list_all_users(self):
        """Return the ``third_party_user_id`` of every row in the table."""
        sql = f"SELECT third_party_user_id FROM {self.table_name}"
        rows = self.db.select(sql, pymysql.cursors.DictCursor)
        return [row['third_party_user_id'] for row in rows]
if __name__ == '__main__':
    # Manual smoke test: read the user-storage settings, build the MySQL
    # backed manager and print one hard-coded user's profile.
    storage_settings = configs.get()['storage']['user']
    manager = MySQLUserManager(storage_settings['mysql'], storage_settings['table'])
    print(manager.get_user_profile('7881301263964433'))
|