瀏覽代碼

Update user_manager: add get_user_tags

StrayWarrior 2 周之前
父節點
當前提交
2ab11d99c1
共有 1 個文件被更改,包括 27 次插入3 次删除
  1. 27 3
      user_manager.py

+ 27 - 3
user_manager.py

@@ -75,6 +75,10 @@ class UserRelationManager(abc.ABC):
     def list_staff_users(self) -> List[Dict]:
         pass
 
+    @abc.abstractmethod
+    def get_user_tags(self, user_id: str) -> List[str]:
+        pass
+
 class LocalUserManager(UserManager):
     def get_user_profile(self, user_id) -> Dict:
         """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
@@ -105,7 +109,6 @@ class LocalUserManager(UserManager):
         return {}
 
 
-
 class MySQLUserManager(UserManager):
     def __init__(self, db_config, table_name, staff_table):
         self.db = MySQLManager(db_config)
@@ -187,6 +190,8 @@ class MySQLUserRelationManager(UserRelationManager):
         return []
 
     def list_staff_users(self):
+        # FIXME(zhoutian)
+        # 测试期间逻辑,只取一个账号
         sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
                f" AND third_party_user_id = '1688854492669990'")
         agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
@@ -236,6 +241,24 @@ class MySQLUserRelationManager(UserRelationManager):
             ret.extend(staff_user_pairs)
         return ret
 
+    def get_user_tags(self, user_id: str) -> List[str]:
+        sql = f"SELECT wxid FROM {self.agent_user_table} WHERE third_party_user_id = '{user_id}' AND wxid is not null"
+        user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
+        if not user_data:
+            logger.error(f"user[{user_id}] has no wxid")
+            return []
+        user_wxid = user_data[0]['wxid']
+        sql = f"""
+            select b.tag_id, c.`tag_name`  from `we_com_user` as a
+              join `we_com_user_with_tag` as b
+              join `we_com_tag` as c
+              on a.`id` = b.`user_id`
+              and b.`tag_id` = c.id
+              where a.union_id = '{user_wxid}' """
+        tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
+        tag_names = [tag['tag_name'] for tag in tag_data]
+        return tag_names
+
 
 if __name__ == '__main__':
     config = configs.get()
@@ -254,5 +277,6 @@ if __name__ == '__main__':
         wecom_db_config['table']['relation'],
         wecom_db_config['table']['user']
     )
-    all_staff_users = user_relation_manager.list_staff_users()
-    print(all_staff_users)
+    # all_staff_users = user_relation_manager.list_staff_users()
+    user_tags = user_relation_manager.get_user_tags('7881302078008656')
+    print(user_tags)