|
@@ -1,13 +1,18 @@
|
|
#! /usr/bin/env python
|
|
#! /usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
# -*- coding: utf-8 -*-
|
|
# vim:fenc=utf-8
|
|
# vim:fenc=utf-8
|
|
-
|
|
|
|
|
|
+import logging
|
|
from typing import Dict, Optional, Tuple, Any
|
|
from typing import Dict, Optional, Tuple, Any
|
|
import json
|
|
import json
|
|
import time
|
|
import time
|
|
import os
|
|
import os
|
|
import abc
|
|
import abc
|
|
|
|
|
|
|
|
+import pymysql.cursors
|
|
|
|
+
|
|
|
|
+import configs
|
|
|
|
+from database import MySQLManager
|
|
|
|
+
|
|
class UserManager(abc.ABC):
|
|
class UserManager(abc.ABC):
|
|
@abc.abstractmethod
|
|
@abc.abstractmethod
|
|
def get_user_profile(self, user_id) -> Dict:
|
|
def get_user_profile(self, user_id) -> Dict:
|
|
@@ -21,6 +26,44 @@ class UserManager(abc.ABC):
|
|
def list_all_users(self):
|
|
def list_all_users(self):
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
+ @staticmethod
|
|
|
|
+ def get_default_profile(**kwargs) -> Dict:
|
|
|
|
+ default_profile = {
|
|
|
|
+ "name": "",
|
|
|
|
+ "nickname": "",
|
|
|
|
+ "preferred_nickname": "",
|
|
|
|
+ "age": 0,
|
|
|
|
+ "region": '',
|
|
|
|
+ "interests": [],
|
|
|
|
+ "family_members": {},
|
|
|
|
+ "health_conditions": [],
|
|
|
|
+ "medications": [],
|
|
|
|
+ "reminder_preferences": {
|
|
|
|
+ "medication": True,
|
|
|
|
+ "health": True,
|
|
|
|
+ "weather": True,
|
|
|
|
+ "news": False
|
|
|
|
+ },
|
|
|
|
+ "interaction_style": "standard", # standard, verbose, concise
|
|
|
|
+ "interaction_frequency": "medium", # low, medium, high
|
|
|
|
+ "last_topics": [],
|
|
|
|
+ "created_at": int(time.time() * 1000),
|
|
|
|
+ "human_intervention_history": []
|
|
|
|
+ }
|
|
|
|
+ for key, value in kwargs.items():
|
|
|
|
+ if key in default_profile:
|
|
|
|
+ default_profile[key] = value
|
|
|
|
+ return default_profile
|
|
|
|
+
|
|
|
|
+class UserRelationManager(abc.ABC):
|
|
|
|
+ @abc.abstractmethod
|
|
|
|
+ def list_staffs(self):
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ @abc.abstractmethod
|
|
|
|
+ def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
|
|
|
|
+ pass
|
|
|
|
+
|
|
class LocalUserManager(UserManager):
|
|
class LocalUserManager(UserManager):
|
|
def get_user_profile(self, user_id) -> Dict:
|
|
def get_user_profile(self, user_id) -> Dict:
|
|
"""加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
|
|
"""加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
|
|
@@ -29,28 +72,7 @@ class LocalUserManager(UserManager):
|
|
return json.load(f)
|
|
return json.load(f)
|
|
except FileNotFoundError:
|
|
except FileNotFoundError:
|
|
# 创建默认用户资料
|
|
# 创建默认用户资料
|
|
- default_profile = {
|
|
|
|
- "name": "",
|
|
|
|
- "nickname": "",
|
|
|
|
- "preferred_nickname": "",
|
|
|
|
- "age": 0,
|
|
|
|
- "region": '',
|
|
|
|
- "interests": [],
|
|
|
|
- "family_members": {},
|
|
|
|
- "health_conditions": [],
|
|
|
|
- "medications": [],
|
|
|
|
- "reminder_preferences": {
|
|
|
|
- "medication": True,
|
|
|
|
- "health": True,
|
|
|
|
- "weather": True,
|
|
|
|
- "news": False
|
|
|
|
- },
|
|
|
|
- "interaction_style": "standard", # standard, verbose, concise
|
|
|
|
- "interaction_frequency": "medium", # low, medium, high
|
|
|
|
- "last_topics": [],
|
|
|
|
- "created_at": int(time.time() * 1000),
|
|
|
|
- "human_intervention_history": []
|
|
|
|
- }
|
|
|
|
|
|
+ default_profile = self.get_default_profile()
|
|
self.save_user_profile(user_id, default_profile)
|
|
self.save_user_profile(user_id, default_profile)
|
|
return default_profile
|
|
return default_profile
|
|
|
|
|
|
@@ -61,9 +83,45 @@ class LocalUserManager(UserManager):
|
|
json.dump(profile, f, ensure_ascii=False, indent=2)
|
|
json.dump(profile, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def list_all_users(self):
|
|
def list_all_users(self):
|
|
- json_files = []
|
|
|
|
|
|
+ user_ids = []
|
|
for root, dirs, files in os.walk('user_profiles/'):
|
|
for root, dirs, files in os.walk('user_profiles/'):
|
|
for file in files:
|
|
for file in files:
|
|
if file.endswith('.json'):
|
|
if file.endswith('.json'):
|
|
- json_files.append(os.path.splitext(file)[0])
|
|
|
|
- return json_files
|
|
|
|
|
|
+ user_ids.append(os.path.splitext(file)[0])
|
|
|
|
+ return user_ids
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class MySQLUserManager(UserManager):
|
|
|
|
+ def __init__(self, db_config, table_name):
|
|
|
|
+ self.db = MySQLManager(db_config)
|
|
|
|
+ self.table_name = table_name
|
|
|
|
+
|
|
|
|
+ def get_user_profile(self, user_id) -> Dict:
|
|
|
|
+ sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
|
|
|
|
+ data = self.db.select(sql, pymysql.cursors.DictCursor)
|
|
|
|
+ if not data:
|
|
|
|
+ logging.error(f"user[{user_id}] not found")
|
|
|
|
+ return {}
|
|
|
|
+ data = data[0]
|
|
|
|
+ if not data['profile_data_v1']:
|
|
|
|
+ logging.warning(f"user[{user_id}] profile not found, create a default one")
|
|
|
|
+ default_profile = self.get_default_profile(nickname=data['name'])
|
|
|
|
+ self.save_user_profile(user_id, default_profile)
|
|
|
|
+ return json.loads(data['profile_data_v1'])
|
|
|
|
+
|
|
|
|
+ def save_user_profile(self, user_id, profile: Dict) -> None:
|
|
|
|
+ if not user_id:
|
|
|
|
+ raise Exception("Invalid user_id: {}".format(user_id))
|
|
|
|
+ sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
|
|
|
|
+ self.db.execute(sql, (json.dumps(profile),))
|
|
|
|
+
|
|
|
|
+ def list_all_users(self):
|
|
|
|
+ sql = f"SELECT third_party_user_id FROM {self.table_name}"
|
|
|
|
+ data = self.db.select(sql, pymysql.cursors.DictCursor)
|
|
|
|
+ return [user['third_party_user_id'] for user in data]
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ db_config = configs.get()['storage']['user']
|
|
|
|
+ user_manager = MySQLUserManager(db_config['mysql'], db_config['table'])
|
|
|
|
+ user_profile = user_manager.get_user_profile('7881301263964433')
|
|
|
|
+ print(user_profile)
|