user_manager.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import logging
  5. from typing import Dict, Optional, Tuple, Any
  6. import json
  7. import time
  8. import os
  9. import abc
  10. import pymysql.cursors
  11. import configs
  12. from database import MySQLManager
  class UserManager(abc.ABC):
      """Abstract storage interface for per-user profile dictionaries.

      Concrete subclasses decide where profiles are persisted. This base class
      also supplies the canonical default profile via ``get_default_profile``.
      """

      @abc.abstractmethod
      def get_user_profile(self, user_id) -> Dict:
          """Return the stored profile dict for ``user_id``."""
          ...

      @abc.abstractmethod
      def save_user_profile(self, user_id, profile: Dict) -> None:
          """Persist ``profile`` as the profile of ``user_id``."""
          ...

      @abc.abstractmethod
      def list_all_users(self):
          """Return the identifiers of every user known to this store."""
          ...

      @staticmethod
      def get_default_profile(**kwargs) -> Dict:
          """Build a fresh default profile, optionally overriding known fields.

          Keyword arguments whose name matches a default field replace that
          field's value; names that are not default fields are silently
          ignored. Each call creates new list/dict containers, so the result
          can be mutated without affecting later calls.

          ``created_at`` is the current time in milliseconds since the epoch,
          taken when this function is called.
          """
          profile = dict(
              name="",
              nickname="",
              preferred_nickname="",
              age=0,
              region='',
              interests=[],
              family_members={},
              health_conditions=[],
              medications=[],
              reminder_preferences=dict(
                  medication=True,
                  health=True,
                  weather=True,
                  news=False,
              ),
              interaction_style="standard",  # one of: standard, verbose, concise
              interaction_frequency="medium",  # one of: low, medium, high
              last_topics=[],
              created_at=int(time.time() * 1000),
              human_intervention_history=[],
          )
          # Only keys that already exist in the defaults may be overridden.
          profile.update({field: val for field, val in kwargs.items() if field in profile})
          return profile
  class UserRelationManager(abc.ABC):
      """Abstract interface for staff / user relationship queries."""

      @abc.abstractmethod
      def list_staffs(self):
          """Return all staff members."""
          ...

      @abc.abstractmethod
      def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
          """Return one page of users associated with ``staff_id``.

          Page numbering presumably starts at 1 (the default); confirm with
          the concrete implementations.
          """
          ...
  class LocalUserManager(UserManager):
      """File-backed profile store: one JSON file per user.

      Mainly intended for local debugging. A user's profile lives in
      ``<profile_dir>/<user_id>.json``.
      """

      def __init__(self, profile_dir: str = "user_profiles"):
          # Defaults to the historical relative directory, so existing
          # ``LocalUserManager()`` callers keep reading the same files.
          self.profile_dir = profile_dir

      def _profile_path(self, user_id) -> str:
          """Return the JSON file path that stores ``user_id``'s profile."""
          return os.path.join(self.profile_dir, f"{user_id}.json")

      def get_user_profile(self, user_id) -> Dict:
          """Load a user's profile, creating and saving a default one if absent.

          Args:
              user_id: Identifier used as the JSON file stem.

          Returns:
              The parsed profile dict, or a freshly created default profile
              (which is also written to disk) when no file exists yet.

          Raises:
              Exception: If no file exists and ``user_id`` is falsy (raised by
                  ``save_user_profile``).
          """
          try:
              with open(self._profile_path(user_id), "r", encoding="utf-8") as f:
                  return json.load(f)
          except FileNotFoundError:
              default_profile = self.get_default_profile()
              self.save_user_profile(user_id, default_profile)
              return default_profile

      def save_user_profile(self, user_id, profile: Dict) -> None:
          """Write ``profile`` to ``user_id``'s JSON file, replacing any content.

          Raises:
              Exception: If ``user_id`` is falsy.
          """
          if not user_id:
              raise Exception("Invalid user_id: {}".format(user_id))
          # The directory may not exist on a fresh checkout; without this the
          # open() below raises FileNotFoundError. That includes the first-run
          # path of get_user_profile.
          os.makedirs(self.profile_dir, exist_ok=True)
          with open(self._profile_path(user_id), "w", encoding="utf-8") as f:
              json.dump(profile, f, ensure_ascii=False, indent=2)

      def list_all_users(self):
          """Return the ids of all stored profiles, sorted.

          Only top-level ``*.json`` files are considered, because
          get_user_profile reads only ``<profile_dir>/<user_id>.json``.
          Returns an empty list when the directory does not exist yet.
          """
          try:
              entries = os.listdir(self.profile_dir)
          except FileNotFoundError:
              return []
          return sorted(
              os.path.splitext(entry)[0]
              for entry in entries
              if entry.endswith('.json')
              and os.path.isfile(os.path.join(self.profile_dir, entry))
          )
  class MySQLUserManager(UserManager):
      """Profile store backed by a MySQL table.

      Each row is keyed by ``third_party_user_id``. The profile is stored as
      JSON text in the ``profile_data_v1`` column.
      """

      def __init__(self, db_config, table_name):
          self.db = MySQLManager(db_config)
          # NOTE: table_name is spliced into SQL verbatim, because identifiers
          # cannot be bound as parameters. It must come from trusted config.
          self.table_name = table_name

      @staticmethod
      def _quote(value) -> str:
          """Return ``value`` as an escaped, single-quoted SQL string literal.

          Used only where MySQLManager.select offers no visible parameter
          slot. TODO(review): switch to bound parameters if select supports
          them. The escaping assumes the server does not run with
          NO_BACKSLASH_ESCAPES.
          """
          import pymysql.converters
          return "'{}'".format(pymysql.converters.escape_string(str(value)))

      def get_user_profile(self, user_id) -> Dict:
          """Fetch a user's profile, creating a default one if the row has none.

          Returns:
              The stored profile dict. If the row exists but has an empty
              profile, a default profile (nickname taken from the row's
              ``name``) is saved and returned. Returns ``{}`` when no row
              matches ``user_id``.
          """
          # The id is quoted, not spliced in bare. This blocks SQL injection
          # and avoids MySQL's lossy numeric coercion when comparing a string
          # column with an unquoted number.
          sql = (f"SELECT name, profile_data_v1 FROM {self.table_name} "
                 f"WHERE third_party_user_id = {self._quote(user_id)}")
          rows = self.db.select(sql, pymysql.cursors.DictCursor)
          if not rows:
              logging.error(f"user[{user_id}] not found")
              return {}
          row = rows[0]
          if not row['profile_data_v1']:
              logging.warning(f"user[{user_id}] profile not found, create a default one")
              default_profile = self.get_default_profile(nickname=row['name'] or "")
              self.save_user_profile(user_id, default_profile)
              # Return the new profile. The row's column is still empty/NULL
              # here, so json.loads on it would raise.
              return default_profile
          return json.loads(row['profile_data_v1'])

      def save_user_profile(self, user_id, profile: Dict) -> None:
          """Store ``profile`` as JSON on the row matching ``user_id``.

          Raises:
              Exception: If ``user_id`` is falsy.
          """
          if not user_id:
              raise Exception("Invalid user_id: {}".format(user_id))
          sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = %s"
          # Both values are bound as parameters, so the driver does the escaping.
          self.db.execute(sql, (json.dumps(profile), user_id))

      def list_all_users(self):
          """Return the ``third_party_user_id`` of every row in the table."""
          sql = f"SELECT third_party_user_id FROM {self.table_name}"
          # select() may return a falsy value (see get_user_profile); treat
          # that as "no users" rather than failing on iteration.
          rows = self.db.select(sql, pymysql.cursors.DictCursor) or []
          return [row['third_party_user_id'] for row in rows]
  if __name__ == '__main__':
      # Manual smoke check: read the user-storage settings from the project
      # config, then fetch and print one hard-coded user's profile.
      user_storage_cfg = configs.get()['storage']['user']
      manager = MySQLUserManager(user_storage_cfg['mysql'], user_storage_cfg['table'])
      print(manager.get_user_profile('7881301263964433'))