소스 검색

Update user_manager: add MySQLUserManager

StrayWarrior 2 주 전
부모
커밋
7cf8884009
1개의 변경된 파일84개의 추가작업 그리고 26개의 파일을 삭제
  1. 84 26
      user_manager.py

+ 84 - 26
user_manager.py

@@ -1,13 +1,18 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-
+import logging
 from typing import Dict, Optional, Tuple, Any
 import json
 import time
 import os
 import abc
 
+import pymysql.cursors
+
+import configs
+from database import MySQLManager
+
 class UserManager(abc.ABC):
     @abc.abstractmethod
     def get_user_profile(self, user_id) -> Dict:
@@ -21,6 +26,44 @@ class UserManager(abc.ABC):
     def list_all_users(self):
         pass
 
+    @staticmethod
+    def get_default_profile(**kwargs) -> Dict:
+        default_profile = {
+            "name": "",
+            "nickname": "",
+            "preferred_nickname": "",
+            "age": 0,
+            "region": '',
+            "interests": [],
+            "family_members": {},
+            "health_conditions": [],
+            "medications": [],
+            "reminder_preferences": {
+                "medication": True,
+                "health": True,
+                "weather": True,
+                "news": False
+            },
+            "interaction_style": "standard",  # standard, verbose, concise
+            "interaction_frequency": "medium",  # low, medium, high
+            "last_topics": [],
+            "created_at": int(time.time() * 1000),
+            "human_intervention_history": []
+        }
+        for key, value in kwargs.items():
+            if key in default_profile:
+                default_profile[key] = value
+        return default_profile
+
+class UserRelationManager(abc.ABC):
+    @abc.abstractmethod
+    def list_staffs(self):
+        pass
+
+    @abc.abstractmethod
+    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
+        pass
+
 class LocalUserManager(UserManager):
     def get_user_profile(self, user_id) -> Dict:
         """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
@@ -29,28 +72,7 @@ class LocalUserManager(UserManager):
                 return json.load(f)
         except FileNotFoundError:
             # 创建默认用户资料
-            default_profile = {
-                "name": "",
-                "nickname": "",
-                "preferred_nickname": "",
-                "age": 0,
-                "region": '',
-                "interests": [],
-                "family_members": {},
-                "health_conditions": [],
-                "medications": [],
-                "reminder_preferences": {
-                    "medication": True,
-                    "health": True,
-                    "weather": True,
-                    "news": False
-                },
-                "interaction_style": "standard",  # standard, verbose, concise
-                "interaction_frequency": "medium",  # low, medium, high
-                "last_topics": [],
-                "created_at": int(time.time() * 1000),
-                "human_intervention_history": []
-            }
+            default_profile = self.get_default_profile()
             self.save_user_profile(user_id, default_profile)
             return default_profile
 
@@ -61,9 +83,45 @@ class LocalUserManager(UserManager):
             json.dump(profile, f, ensure_ascii=False, indent=2)
 
     def list_all_users(self):
-        json_files = []
+        user_ids = []
         for root, dirs, files in os.walk('user_profiles/'):
             for file in files:
                 if file.endswith('.json'):
-                    json_files.append(os.path.splitext(file)[0])
-        return json_files
+                    user_ids.append(os.path.splitext(file)[0])
+        return user_ids
+
+
+class MySQLUserManager(UserManager):
+    def __init__(self, db_config, table_name):
+        self.db = MySQLManager(db_config)
+        self.table_name = table_name
+
+    def get_user_profile(self, user_id) -> Dict:
+        sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        if not data:
+            logging.error(f"user[{user_id}] not found")
+            return {}
+        data = data[0]
+        if not data['profile_data_v1']:
+            logging.warning(f"user[{user_id}] profile not found, create a default one")
+            default_profile = self.get_default_profile(nickname=data['name'])
+            self.save_user_profile(user_id, default_profile)
+        return json.loads(data['profile_data_v1'])
+
+    def save_user_profile(self, user_id, profile: Dict) -> None:
+        if not user_id:
+            raise Exception("Invalid user_id: {}".format(user_id))
+        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = {user_id}"
+        self.db.execute(sql, (json.dumps(profile),))
+
+    def list_all_users(self):
+        sql = f"SELECT third_party_user_id FROM {self.table_name}"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        return [user['third_party_user_id'] for user in data]
+
+if __name__ == '__main__':
+    db_config = configs.get()['storage']['user']
+    user_manager = MySQLUserManager(db_config['mysql'], db_config['table'])
+    user_profile = user_manager.get_user_profile('7881301263964433')
+    print(user_profile)