Browse Source

Add staff profile

StrayWarrior 3 months ago
parent
commit
96a0bf646c
3 changed files with 32 additions and 4 deletions
  1. 2 1
      agent_service.py
  2. 5 1
      dialogue_manager.py
  3. 25 2
      user_manager.py

+ 2 - 1
agent_service.py

@@ -232,10 +232,11 @@ if __name__ == "__main__":
     # 初始化用户管理服务
     # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
     user_db_config = config['storage']['user']
+    staff_db_config = config['storage']['staff']
     if config['debug_flags'].get('use_local_user_storage', False):
         user_manager = LocalUserManager()
     else:
-        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'])
+        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
 
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(

+ 5 - 1
dialogue_manager.py

@@ -98,6 +98,7 @@ class DialogueManager:
         # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
         self.dialogue_history = []
         self.user_profile = self.user_manager.get_user_profile(user_id)
+        self.staff_profile = self.user_manager.get_staff_profile(staff_id)
         self.last_interaction_time = 0
         self.consecutive_clarifications = 0
         self.complex_request_counter = 0
@@ -385,6 +386,8 @@ class DialogueManager:
         time_context = self.get_current_time_context()
         # 刷新用户画像
         self.user_profile = self.user_manager.get_user_profile(self.user_id)
+        # 刷新员工画像(不一定需要)
+        self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
 
         context = {
             "user_profile": self.user_profile,
@@ -395,7 +398,8 @@ class DialogueManager:
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": False,
             "if_active_greeting": False if user_message else True,
-            **self.user_profile
+            **self.user_profile,
+            **self.staff_profile
         }
 
         # 获取长期记忆

+ 25 - 2
user_manager.py

@@ -26,6 +26,11 @@ class UserManager(abc.ABC):
     def list_all_users(self):
         pass
 
+    @abc.abstractmethod
+    def get_staff_profile(self, staff_id) -> Dict:
+        #FIXME(zhoutian): 重新设计用户和员工数据管理模型
+        pass
+
     @staticmethod
     def get_default_profile(**kwargs) -> Dict:
         default_profile = {
@@ -94,11 +99,16 @@ class LocalUserManager(UserManager):
                     user_ids.append(os.path.splitext(file)[0])
         return user_ids
 
+    def get_staff_profile(self, staff_id) -> Dict:
+        return {}
+
+
 
 class MySQLUserManager(UserManager):
-    def __init__(self, db_config, table_name):
+    def __init__(self, db_config, table_name, staff_table):
         self.db = MySQLManager(db_config)
         self.table_name = table_name
+        self.staff_table = staff_table
 
     def get_user_profile(self, user_id) -> Dict:
         sql = f"SELECT name, wxid, profile_data_v1 FROM {self.table_name} WHERE third_party_user_id = {user_id}"
@@ -126,6 +136,18 @@ class MySQLUserManager(UserManager):
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         return [user['third_party_user_id'] for user in data]
 
+    def get_staff_profile(self, staff_id) -> Dict:
+        if not self.staff_table:
+            raise Exception("staff_table is not set")
+        sql = f"SELECT agent_name, agent_age, agent_region " \
+              f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        if not data:
+            logging.error(f"staff[{staff_id}] not found")
+            return {}
+        profile = data[0]
+        return profile
+
 
 class MySQLUserRelationManager(UserRelationManager):
     def __init__(self, agent_db_config, wecom_db_config,
@@ -199,7 +221,8 @@ class MySQLUserRelationManager(UserRelationManager):
 if __name__ == '__main__':
     config = configs.get()
     user_db_config = config['storage']['user']
-    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'])
+    staff_db_config = config['storage']['staff']
+    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
     user_profile = user_manager.get_user_profile('7881301263964433')
     print(user_profile)