Browse Source

Update user_manager: add MySQLUserManager

StrayWarrior 2 weeks ago
parent
commit
7cf8884009
1 changed files with 84 additions and 26 deletions
  1. 84 26
      user_manager.py

+ 84 - 26
user_manager.py

@@ -1,13 +1,18 @@
 #! /usr/bin/env python
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
 # vim:fenc=utf-8
-
+import logging
 from typing import Dict, Optional, Tuple, Any
 from typing import Dict, Optional, Tuple, Any
 import json
 import json
 import time
 import time
 import os
 import os
 import abc
 import abc
 
 
+import pymysql.cursors
+
+import configs
+from database import MySQLManager
+
 class UserManager(abc.ABC):
 class UserManager(abc.ABC):
     @abc.abstractmethod
     @abc.abstractmethod
     def get_user_profile(self, user_id) -> Dict:
     def get_user_profile(self, user_id) -> Dict:
@@ -21,6 +26,44 @@ class UserManager(abc.ABC):
     def list_all_users(self):
     def list_all_users(self):
         pass
         pass
 
 
+    @staticmethod
+    def get_default_profile(**kwargs) -> Dict:
+        default_profile = {
+            "name": "",
+            "nickname": "",
+            "preferred_nickname": "",
+            "age": 0,
+            "region": '',
+            "interests": [],
+            "family_members": {},
+            "health_conditions": [],
+            "medications": [],
+            "reminder_preferences": {
+                "medication": True,
+                "health": True,
+                "weather": True,
+                "news": False
+            },
+            "interaction_style": "standard",  # standard, verbose, concise
+            "interaction_frequency": "medium",  # low, medium, high
+            "last_topics": [],
+            "created_at": int(time.time() * 1000),
+            "human_intervention_history": []
+        }
+        for key, value in kwargs.items():
+            if key in default_profile:
+                default_profile[key] = value
+        return default_profile
+
+class UserRelationManager(abc.ABC):
+    @abc.abstractmethod
+    def list_staffs(self):
+        pass
+
+    @abc.abstractmethod
+    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
+        pass
+
 class LocalUserManager(UserManager):
 class LocalUserManager(UserManager):
     def get_user_profile(self, user_id) -> Dict:
     def get_user_profile(self, user_id) -> Dict:
         """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
         """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
@@ -29,28 +72,7 @@ class LocalUserManager(UserManager):
                 return json.load(f)
                 return json.load(f)
         except FileNotFoundError:
         except FileNotFoundError:
             # 创建默认用户资料
             # 创建默认用户资料
-            default_profile = {
-                "name": "",
-                "nickname": "",
-                "preferred_nickname": "",
-                "age": 0,
-                "region": '',
-                "interests": [],
-                "family_members": {},
-                "health_conditions": [],
-                "medications": [],
-                "reminder_preferences": {
-                    "medication": True,
-                    "health": True,
-                    "weather": True,
-                    "news": False
-                },
-                "interaction_style": "standard",  # standard, verbose, concise
-                "interaction_frequency": "medium",  # low, medium, high
-                "last_topics": [],
-                "created_at": int(time.time() * 1000),
-                "human_intervention_history": []
-            }
+            default_profile = self.get_default_profile()
             self.save_user_profile(user_id, default_profile)
             self.save_user_profile(user_id, default_profile)
             return default_profile
             return default_profile
 
 
@@ -61,9 +83,45 @@ class LocalUserManager(UserManager):
             json.dump(profile, f, ensure_ascii=False, indent=2)
             json.dump(profile, f, ensure_ascii=False, indent=2)
 
 
     def list_all_users(self):
     def list_all_users(self):
-        json_files = []
+        user_ids = []
         for root, dirs, files in os.walk('user_profiles/'):
         for root, dirs, files in os.walk('user_profiles/'):
             for file in files:
             for file in files:
                 if file.endswith('.json'):
                 if file.endswith('.json'):
-                    json_files.append(os.path.splitext(file)[0])
-        return json_files
+                    user_ids.append(os.path.splitext(file)[0])
+        return user_ids
+
+
+class MySQLUserManager(UserManager):
+    def __init__(self, db_config, table_name):
+        self.db = MySQLManager(db_config)
+        self.table_name = table_name
+
+    def get_user_profile(self, user_id) -> Dict:
+        sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        if not data:
+            logging.error(f"user[{user_id}] not found")
+            return {}
+        data = data[0]
+        if not data['profile_data_v1']:
+            logging.warning(f"user[{user_id}] profile not found, create a default one")
+            default_profile = self.get_default_profile(nickname=data['name'])
+            self.save_user_profile(user_id, default_profile)
+        return json.loads(data['profile_data_v1'])
+
+    def save_user_profile(self, user_id, profile: Dict) -> None:
+        if not user_id:
+            raise Exception("Invalid user_id: {}".format(user_id))
+        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
+        self.db.execute(sql, (json.dumps(profile),))
+
+    def list_all_users(self):
+        sql = f"SELECT third_party_user_id FROM {self.table_name}"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        return [user['third_party_user_id'] for user in data]
+
+if __name__ == '__main__':
+    db_config = configs.get()['storage']['user']
+    user_manager = MySQLUserManager(db_config['mysql'], db_config['table'])
+    user_profile = user_manager.get_user_profile('7881301263964433')
+    print(user_profile)