浏览代码

Update user_manager: use user relation and tag data in agent database

StrayWarrior 2 天之前
父节点
当前提交
58bf5e34a0
共有 1 个文件被更改,包括 47 次插入2 次删除
  1. 47 2
      pqai_agent/user_manager.py

+ 47 - 2
pqai_agent/user_manager.py

@@ -1,6 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+from abc import abstractmethod
 
 from pqai_agent.logging_service import logger
 from typing import Dict, Optional, List
@@ -33,6 +34,10 @@ class UserManager(abc.ABC):
         #FIXME(zhoutian): 重新设计用户和员工数据管理模型
         pass
 
+    @abstractmethod
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        pass
+
     @staticmethod
     def get_default_profile(**kwargs) -> Dict:
         default_profile = {
@@ -133,6 +138,9 @@ class LocalUserManager(UserManager):
             logger.error("staff profile not found: {}".format(e))
             return {}
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        return {}
+
     def list_users(self, **kwargs) -> List[Dict]:
         pass
 
@@ -249,6 +257,34 @@ class MySQLUserManager(UserManager):
         sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = '{staff_id}'"
         self.db.execute(sql, (json.dumps(profile),))
 
+    def get_user_tags(self, user_ids: List[str], batch_size = 500) -> Dict[str, List[str]]:
+        """
+        获取用户的标签列表
+        :param user_ids: 用户ID
+        :param batch_size: 批量查询的大小
+        :return: 标签名称列表
+        """
+        batches = (len(user_ids) + batch_size - 1) // batch_size
+        ret = {}
+        for i in range(batches):
+            idx_begin = i * batch_size
+            idx_end = min((i + 1) * batch_size, len(user_ids))
+            batch_user_ids = user_ids[idx_begin:idx_end]
+            sql = f"""
+                SELECT a.third_party_user_id, b.tag_id, b.name FROM qywx_user_tag a
+                    JOIN qywx_tag b ON a.tag_id = b.tag_id
+                    AND a.third_party_user_id IN {str(tuple(batch_user_ids))}
+                """
+            rows = self.db.select(sql, pymysql.cursors.DictCursor)
+            # group by user_id
+            for row in rows:
+                user_id = row['third_party_user_id']
+                tag_name = row['name']
+                if user_id not in ret:
+                    ret[user_id] = []
+                ret[user_id].append(tag_name)
+        return ret
+
     def list_users(self, **kwargs) -> List[Dict]:
         user_union_id = kwargs.get('user_union_id', None)
         user_name = kwargs.get('user_name', None)
@@ -333,6 +369,7 @@ class MySQLUserRelationManager(UserRelationManager):
         self.relation_table = relation_table
         self.agent_user_table = agent_user_table
         self.user_table = user_table
+        self.agent_user_relation_table = 'qywx_employee_customer'
 
     def list_staffs(self):
         sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
@@ -342,7 +379,14 @@ class MySQLUserRelationManager(UserRelationManager):
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
         return []
 
-    def list_staff_users(self, staff_id: str = None, tag_id: int = None):
+    def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
+        sql = f"SELECT employee_id as staff_id, customer_id as user_id FROM {self.agent_user_relation_table} WHERE 1 = 1"
+        if staff_id:
+            sql += f" AND employee_id = '{staff_id}'"
+        agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
+        return agent_staff_data
+
+    def list_staff_users_v1(self, staff_id: str = None, tag_id: int = None):
         sql = f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
         if staff_id:
             sql += f" AND third_party_user_id = '{staff_id}'"
@@ -382,7 +426,8 @@ class MySQLUserRelationManager(UserRelationManager):
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
                 batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
-                    # logger.debug(f"staff[{wxid}] some users not found in agent database")
+                    diff_num = len(batch_union_ids) - len(batch_agent_user_data)
+                    logger.debug(f"staff[{staff_id}] {diff_num} users not found in agent database")
                     pass
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [