Browse Source

Revert "develop-0520"

This reverts commit 7b327448
luojunhui 1 month ago
parent
commit
a6bddb2be8
3 changed files with 120 additions and 217 deletions
  1. 108 185
      pqai_agent/user_manager.py
  2. 11 32
      pqai_agent_server/api_server.py
  3. 1 0
      requirements.txt

+ 108 - 185
pqai_agent/user_manager.py

@@ -30,7 +30,7 @@ class UserManager(abc.ABC):
 
     @abc.abstractmethod
     def get_staff_profile(self, staff_id) -> Dict:
-        # FIXME(zhoutian): 重新设计用户和员工数据管理模型
+        #FIXME(zhoutian): 重新设计用户和员工数据管理模型
         pass
 
     @staticmethod
@@ -42,7 +42,7 @@ class UserManager(abc.ABC):
             "preferred_nickname": "",
             "gender": "未知",
             "age": 0,
-            "region": "",
+            "region": '',
             "interests": [],
             "family_members": {},
             "health_conditions": [],
@@ -51,13 +51,13 @@ class UserManager(abc.ABC):
                 "medication": True,
                 "health": True,
                 "weather": True,
-                "news": False,
+                "news": False
             },
             "interaction_style": "standard",  # standard, verbose, concise
             "interaction_frequency": "medium",  # low, medium, high
             "last_topics": [],
             "created_at": int(time.time() * 1000),
-            "human_intervention_history": [],
+            "human_intervention_history": []
         }
         for key, value in kwargs.items():
             if key in default_profile:
@@ -67,7 +67,6 @@ class UserManager(abc.ABC):
     def list_users(self, **kwargs) -> List[Dict]:
         pass
 
-
 class UserRelationManager(abc.ABC):
     @abc.abstractmethod
     def list_staffs(self):
@@ -89,7 +88,6 @@ class UserRelationManager(abc.ABC):
     def stop_user_daily_push(self, user_id: str) -> bool:
         pass
 
-
 class LocalUserManager(UserManager):
     def get_user_profile(self, user_id) -> Dict:
         """加载用户个人资料,如不存在则创建默认资料。主要用于本地调试"""
@@ -100,9 +98,7 @@ class LocalUserManager(UserManager):
             entry_added = False
             for key, value in default_profile.items():
                 if key not in profile:
-                    logger.debug(
-                        f"user[{user_id}] add profile key[{key}] value[{value}]"
-                    )
+                    logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
                     profile[key] = value
                     entry_added = True
             if entry_added:
@@ -121,9 +117,9 @@ class LocalUserManager(UserManager):
 
     def list_all_users(self):
         user_ids = []
-        for root, dirs, files in os.walk("../user_profiles/"):
+        for root, dirs, files in os.walk('../user_profiles/'):
             for file in files:
-                if file.endswith(".json"):
+                if file.endswith('.json'):
                     user_ids.append(os.path.splitext(file)[0])
         return user_ids
 
@@ -142,11 +138,8 @@ class LocalUserManager(UserManager):
     def list_users(self, **kwargs) -> List[Dict]:
         pass
 
-
 class MySQLUserManager(UserManager):
-    PROFILE_EXCLUDE_ITEMS = [
-        "avatar",
-    ]
+    PROFILE_EXCLUDE_ITEMS = ['avatar', ]
 
     def __init__(self, db_config, table_name, staff_table):
         self.db = MySQLManager(db_config)
@@ -154,26 +147,22 @@ class MySQLUserManager(UserManager):
         self.staff_table = staff_table
 
     def get_user_profile(self, user_id) -> Dict:
-        sql = (
-            f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar"
-            f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
-        )
+        sql = f"SELECT name, wxid, profile_data_v1, gender, iconurl as avatar" \
+              f" FROM {self.table_name} WHERE third_party_user_id = {user_id}"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
             logger.error(f"user[{user_id}] not found")
             return {}
         data = data[0]
-        gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
-        gender = gender_map[data["gender"]]
-        default_profile = self.get_default_profile(
-            nickname=data["name"], gender=gender, avatar=data["avatar"]
-        )
-        if not data["profile_data_v1"]:
+        gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
+        gender = gender_map[data['gender']]
+        default_profile = self.get_default_profile(nickname=data['name'], gender=gender, avatar=data['avatar'])
+        if not data['profile_data_v1']:
             logger.warning(f"user[{user_id}] profile not found, create a default one")
             self.save_user_profile(user_id, default_profile)
             return default_profile
         else:
-            profile = json.loads(data["profile_data_v1"])
+            profile = json.loads(data['profile_data_v1'])
             # 资料条目有增加时,需合并更新
             entry_added = False
             for key, value in default_profile.items():
@@ -188,7 +177,7 @@ class MySQLUserManager(UserManager):
     def save_user_profile(self, user_id, profile: Dict) -> None:
         if not user_id:
             raise Exception("Invalid user_id: {}".format(user_id))
-        if configs.get().get("debug_flags", {}).get("disable_database_write", False):
+        if configs.get().get('debug_flags', {}).get('disable_database_write', False):
             return
         profile = profile.copy()
         for name in self.PROFILE_EXCLUDE_ITEMS:
@@ -199,7 +188,7 @@ class MySQLUserManager(UserManager):
     def list_all_users(self):
         sql = f"SELECT third_party_user_id FROM {self.table_name}"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
-        return [user["third_party_user_id"] for user in data]
+        return [user['third_party_user_id'] for user in data]
 
     def get_staff_profile(self, staff_id) -> Dict:
         if not self.staff_table:
@@ -207,48 +196,42 @@ class MySQLUserManager(UserManager):
         return self.get_staff_profile_v3(staff_id)
 
     def get_staff_profile_v1(self, staff_id) -> Dict:
-        sql = (
-            f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile "
-            f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
-        )
+        sql = f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile " \
+              f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
             logger.error(f"staff[{staff_id}] not found")
             return {}
         profile = data[0]
         # 转换性别格式
-        gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
-        profile["agent_gender"] = gender_map[profile["agent_gender"]]
+        gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
+        profile['agent_gender'] = gender_map[profile['agent_gender']]
         return profile
 
     def get_staff_profile_v2(self, staff_id) -> Dict:
-        sql = (
-            f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile "
-            f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
-        )
+        sql = f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile " \
+              f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
             logger.error(f"staff[{staff_id}] not found")
             return {}
         profile = data[0]
         # 转换性别格式
-        gender_map = {0: "未知", 1: "男", 2: "女", None: "未知"}
-        profile["gender"] = gender_map[profile["gender"]]
+        gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
+        profile['gender'] = gender_map[profile['gender']]
 
         # 合并JSON字段(新版本)数据
-        if profile["agent_profile"]:
-            detail_profile = json.loads(profile["agent_profile"])
+        if profile['agent_profile']:
+            detail_profile = json.loads(profile['agent_profile'])
             profile.update(detail_profile)
 
         # 去除原始字段
-        profile.pop("agent_profile", None)
+        profile.pop('agent_profile', None)
         return profile
 
     def get_staff_profile_v3(self, staff_id) -> Dict:
-        sql = (
-            f"SELECT agent_profile "
-            f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
-        )
+        sql = f"SELECT agent_profile " \
+              f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql)
         if not data:
             logger.error(f"staff[{staff_id}] not found")
@@ -269,8 +252,8 @@ class MySQLUserManager(UserManager):
         self.db.execute(sql, (json.dumps(profile),))
 
     def list_users(self, **kwargs) -> List[Dict]:
-        user_union_id = kwargs.get("user_union_id", None)
-        user_name = kwargs.get("user_name", None)
+        user_union_id = kwargs.get('user_union_id', None)
+        user_name = kwargs.get('user_name', None)
         if not user_union_id and not user_name:
             raise Exception("user_union_id or user_name is required")
         sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
@@ -281,13 +264,7 @@ class MySQLUserManager(UserManager):
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         return data
 
-    def get_staff_sessions(
-        self,
-        staff_id,
-        page_id: int = 1,
-        page_size: int = 10,
-        session_type: str = "default",
-    ) -> List[Dict]:
+    def get_staff_sessions(self, staff_id, page_id: int = 1, page_size: int = 10, session_type: str = 'default') -> List[Dict]:
         """
         :param page_size:
         :param page_id:
@@ -296,14 +273,14 @@ class MySQLUserManager(UserManager):
         :return:
         """
         match session_type:
-            case "active":
+            case 'active':
                 sql = f"""
                     select staff_id, current_state, user_id
                     from agent_state
                     where staff_id = %s and update_timestamp >= DATE_SUB(NOW(), INTERVAL 2 HOUR)
                     order by update_timestamp desc;
                 """
-            case "human_intervention":
+            case 'human_intervention':
                 sql = f"""
                     select staff_id, current_state, user_id
                     from agent_state
@@ -321,14 +298,10 @@ class MySQLUserManager(UserManager):
                     limit {page_size + 1} offset {page_size * (page_id - 1)};
                 """
 
-        staff_sessions = self.db.select(
-            sql, cursor_type=pymysql.cursors.DictCursor, args=(staff_id,)
-        )
+        staff_sessions = self.db.select(sql, cursor_type=pymysql.cursors.DictCursor, args=(staff_id, ))
         return staff_sessions
 
-    def get_staff_sessions_summary_v1(
-        self, staff_id, page_id: int, page_size: int, status: int
-    ) -> Dict:
+    def get_staff_sessions_summary_v1(self, staff_id, page_id: int, page_size: int, status: int) -> Dict:
         """
         :param status: staff status(0: unemployed, 1: employed)
         :param staff_id: staff
@@ -345,7 +318,7 @@ class MySQLUserManager(UserManager):
             staff_id_list = self.db.select(
                 sql=get_staff_query,
                 cursor_type=pymysql.cursors.DictCursor,
-                args=(status, page_size + 1, (page_id - 1) * page_size),
+                args=(status, page_size + 1, (page_id - 1) * page_size)
             )
             if not staff_id_list:
                 return {}
@@ -365,7 +338,7 @@ class MySQLUserManager(UserManager):
             staff_id_list = self.db.select(
                 sql=get_staff_query,
                 cursor_type=pymysql.cursors.DictCursor,
-                args=(status, staff_id),
+                args=(status, staff_id)
             )
             if not staff_id_list:
                 return {}
@@ -373,25 +346,17 @@ class MySQLUserManager(UserManager):
             next_page_id = None
         response_data = [
             {
-                "staff_id": staff["third_party_user_id"],
-                "staff_name": staff["name"],
-                "active_sessions": len(
-                    self.get_staff_sessions(
-                        staff["third_party_user_id"], session_type="active"
-                    )
-                ),
-                "human_intervention_sessions": len(
-                    self.get_staff_sessions(
-                        staff["third_party_user_id"], session_type="human_intervention"
-                    )
-                ),
+                'staff_id': staff['third_party_user_id'],
+                'staff_name': staff['name'],
+                'active_sessions': len(self.get_staff_sessions(staff['third_party_user_id'], session_type='active')),
+                'human_intervention_sessions': len(self.get_staff_sessions(staff['third_party_user_id'], session_type='human_intervention'))
             }
             for staff in staff_id_list
         ]
         return {
-            "has_next_page": has_next_page,
-            "next_page_id": next_page_id,
-            "data": response_data,
+            'has_next_page': has_next_page,
+            'next_page_id': next_page_id,
+            'data': response_data
         }
 
     def get_staff_session_list_v1(self, staff_id, page_id: int, page_size: int) -> Dict:
@@ -412,29 +377,29 @@ class MySQLUserManager(UserManager):
         response_data = []
         for session in session_list:
             temp_obj = {}
-            user_id = session["user_id"]
-            room_id = ":".join(["private", staff_id, user_id])
+            user_id = session['user_id']
+            room_id = ':'.join(['private', staff_id, user_id])
             select_query = f"""select content, max(sendtime) as max_timestamp from qywx_chat_history where roomid = %s;"""
             last_message = self.db.select(
                 sql=select_query,
                 cursor_type=pymysql.cursors.DictCursor,
-                args=(room_id,),
+                args=(room_id,)
             )
             if not last_message:
-                temp_obj["message"] = ""
-                temp_obj["timestamp"] = 0
+                temp_obj['message'] = ''
+                temp_obj['timestamp'] = 0
             else:
-                temp_obj["message"] = last_message[0]["content"]
-                temp_obj["timestamp"] = last_message[0]["max_timestamp"]
-            temp_obj["customer_id"] = user_id
-            temp_obj["customer_name"] = session["name"]
-            temp_obj["avatar"] = session["iconurl"]
+                temp_obj['message'] = last_message[0]['content']
+                temp_obj['timestamp'] = last_message[0]['max_timestamp']
+            temp_obj['customer_id'] = user_id
+            temp_obj['customer_name'] = session['name']
+            temp_obj['avatar'] = session['iconurl']
             response_data.append(temp_obj)
         return {
             "staff_id": staff_id,
             "has_next_page": has_next_page,
             "next_page_id": next_page_id,
-            "data": response_data,
+            "data": response_data
         }
 
     def get_staff_list(self, page_id: int, page_size: int) -> Dict:
@@ -452,7 +417,7 @@ class MySQLUserManager(UserManager):
         staff_list = self.db.select(
             sql=sql,
             cursor_type=pymysql.cursors.DictCursor,
-            args=(page_size + 1, page_size * (page_id - 1)),
+            args=(page_size + 1, page_size * (page_id - 1))
         )
         if len(staff_list) > page_size:
             has_next_page = True
@@ -464,23 +429,21 @@ class MySQLUserManager(UserManager):
         return {
             "has_next_page": has_next_page,
             "next_page": next_page_id,
-            "data": staff_list,
+            "data": staff_list
         }
 
-    def get_conversation_list_v1(
-        self, staff_id: str, customer_id: str, page: Optional[int], page_size: int
-    ):
+    def get_conversation_list_v1(self, staff_id: str, customer_id: str, page: Optional[int]):
         """
         :param staff_id:
         :param customer_id:
         :param page: timestamp
-        :param page_size:
         :return:
         """
-        room_id = ":".join(["private", staff_id, customer_id])
+        room_id = ':'.join(['private', staff_id, customer_id])
+        page_size = 20
         if not page:
             fetch_query = f"""
-                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
+                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl
                 from qywx_chat_history t1
                 join third_party_user t2 on t1.sender = t2.third_party_user_id
                 where roomid = %s
@@ -490,11 +453,11 @@ class MySQLUserManager(UserManager):
             messages = self.db.select(
                 sql=fetch_query,
                 cursor_type=pymysql.cursors.DictCursor,
-                args=(room_id, page_size + 1),
+                args=(room_id, page_size + 1)
             )
         else:
             fetch_query = f"""
-                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
+                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl
                 from qywx_chat_history t1
                 join third_party_user t2 on t1.sender = t2.third_party_user_id
                 where t1.roomid = %s and t1.sendtime <= %s
@@ -504,33 +467,32 @@ class MySQLUserManager(UserManager):
             messages = self.db.select(
                 sql=fetch_query,
                 cursor_type=pymysql.cursors.DictCursor,
-                args=(room_id, page, page_size + 1),
+                args=(room_id, page, page_size + 1)
             )
         if messages:
             if len(messages) > page_size:
                 has_next_page = True
-                next_page = messages[-1]["sendtime"]
+                next_page = messages[-1]['sendtime']
             else:
                 has_next_page = False
                 next_page = None
             response_data = [
                 {
-                    "sender_id": message["sender"],
-                    "sender_name": message["name"],
-                    "avatar": message["iconurl"],
-                    "content": message["content"],
-                    "timestamp": message["sendtime"],
-                    "role": "customer" if message["sender"] == customer_id else "staff",
-                    "message_type": message["msg_type"],
+                    "sender_id": message['sender'],
+                    "sender_name": message['name'],
+                    "avatar": message['iconurl'],
+                    "content": message['content'],
+                    "timestamp": message['sendtime'],
+                    "role": "customer" if message['sender'] == customer_id else "staff"
                 }
-                for message in messages[ :page_size]
+                for message in messages
             ]
             return {
                 "staff_id": staff_id,
                 "customer_id": customer_id,
                 "has_next_page": has_next_page,
                 "next_page": next_page,
-                "data": response_data,
+                "data": response_data
             }
         else:
             has_next_page = False
@@ -540,7 +502,7 @@ class MySQLUserManager(UserManager):
                 "customer_id": customer_id,
                 "has_next_page": has_next_page,
                 "next_page": next_page,
-                "data": [],
+                "data": []
             }
 
 
@@ -550,47 +512,21 @@ class LocalUserRelationManager(UserRelationManager):
 
     def list_staffs(self):
         return [
-            {
-                "third_party_user_id": "1688855931724582",
-                "name": "",
-                "wxid": "ShengHuoLeQu",
-                "agent_name": "小芳",
-            }
+            {"third_party_user_id": '1688855931724582', "name": "", "wxid": "ShengHuoLeQu", "agent_name": "小芳"}
         ]
 
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
         return []
 
     def list_staff_users(self, staff_id: str = None, tag_id: int = None):
-        user_ids = [
-            "7881299453089278",
-            "7881299453132630",
-            "7881299454186909",
-            "7881299455103430",
-            "7881299455173476",
-            "7881299456216398",
-            "7881299457990953",
-            "7881299461167644",
-            "7881299463002136",
-            "7881299464081604",
-            "7881299465121735",
-            "7881299465998082",
-            "7881299466221881",
-            "7881299467152300",
-            "7881299470051791",
-            "7881299470112816",
-            "7881299471149567",
-            "7881299471168030",
-            "7881299471277650",
-            "7881299473321703",
-        ]
+        user_ids = ['7881299453089278', '7881299453132630', '7881299454186909', '7881299455103430', '7881299455173476',
+                    '7881299456216398', '7881299457990953', '7881299461167644', '7881299463002136', '7881299464081604',
+                    '7881299465121735', '7881299465998082', '7881299466221881', '7881299467152300', '7881299470051791',
+                    '7881299470112816', '7881299471149567', '7881299471168030', '7881299471277650', '7881299473321703']
         user_ids = user_ids[:5]
         return [
             {"staff_id": "1688855931724582", "user_id": "7881299670930896"},
-            *[
-                {"staff_id": "1688855931724582", "user_id": user_id}
-                for user_id in user_ids
-            ],
+            *[{"staff_id": "1688855931724582", "user_id": user_id} for user_id in user_ids]
         ]
 
     def get_user_tags(self, user_id: str):
@@ -599,18 +535,10 @@ class LocalUserRelationManager(UserRelationManager):
     def stop_user_daily_push(self, user_id: str) -> bool:
         return True
 
-
 class MySQLUserRelationManager(UserRelationManager):
-    def __init__(
-        self,
-        agent_db_config,
-        wecom_db_config,
-        agent_staff_table,
-        agent_user_table,
-        staff_table,
-        relation_table,
-        user_table,
-    ):
+    def __init__(self, agent_db_config, wecom_db_config,
+                 agent_staff_table, agent_user_table,
+                 staff_table, relation_table, user_table):
         # FIXME(zhoutian): 因为现在数据库表不统一,需要从两个库读取
         self.agent_db = MySQLManager(agent_db_config)
         self.wecom_db = MySQLManager(wecom_db_config)
@@ -637,19 +565,19 @@ class MySQLUserRelationManager(UserRelationManager):
             return []
         ret = []
         for agent_staff in agent_staff_data:
-            wxid = agent_staff["wxid"]
+            wxid = agent_staff['wxid']
             sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = '{wxid}'"
             staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not staff_data:
                 logger.error(f"staff[{wxid}] not found in wecom database")
                 continue
-            staff_id = staff_data[0]["id"]
+            staff_id = staff_data[0]['id']
             sql = f"SELECT user_id FROM {self.relation_table} WHERE staff_id = '{staff_id}' AND is_delete = 0"
             user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
             if not user_data:
                 logger.warning(f"staff[{wxid}] has no user")
                 continue
-            user_ids = tuple(user["user_id"] for user in user_data)
+            user_ids = tuple(user['user_id'] for user in user_data)
             sql = f"SELECT union_id FROM {self.user_table} WHERE id IN {str(user_ids)} AND union_id is not null"
             if tag_id:
                 sql += f" AND id in (SELECT distinct user_id FROM we_com_user_with_tag WHERE tag_id = {tag_id} and is_delete = 0)"
@@ -657,7 +585,7 @@ class MySQLUserRelationManager(UserRelationManager):
             if not user_data:
                 logger.warning(f"staff[{wxid}] users not found in wecom database")
                 continue
-            user_union_ids = tuple(user["union_id"] for user in user_data)
+            user_union_ids = tuple(user['union_id'] for user in user_data)
             batch_size = 500
             n_batches = (len(user_union_ids) + batch_size - 1) // batch_size
             agent_user_data = []
@@ -666,17 +594,15 @@ class MySQLUserRelationManager(UserRelationManager):
                 idx_end = min((i + 1) * batch_size, len(user_union_ids))
                 batch_union_ids = user_union_ids[idx_begin:idx_end]
                 sql = f"SELECT third_party_user_id, wxid FROM {self.agent_user_table} WHERE wxid IN {str(batch_union_ids)}"
-                batch_agent_user_data = self.agent_db.select(
-                    sql, pymysql.cursors.DictCursor
-                )
+                batch_agent_user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
                 if len(agent_user_data) != len(batch_union_ids):
                     # logger.debug(f"staff[{wxid}] some users not found in agent database")
                     pass
                 agent_user_data.extend(batch_agent_user_data)
             staff_user_pairs = [
                 {
-                    "staff_id": agent_staff["third_party_user_id"],
-                    "user_id": agent_user["third_party_user_id"],
+                    'staff_id': agent_staff['third_party_user_id'],
+                    'user_id': agent_user['third_party_user_id']
                 }
                 for agent_user in agent_user_data
             ]
@@ -689,7 +615,7 @@ class MySQLUserRelationManager(UserRelationManager):
         if not user_data:
             logger.error(f"user[{user_id}] has no union id")
             return None
-        union_id = user_data[0]["wxid"]
+        union_id = user_data[0]['wxid']
         return union_id
 
     def get_user_tags(self, user_id: str) -> List[str]:
@@ -704,7 +630,7 @@ class MySQLUserRelationManager(UserRelationManager):
               and b.`tag_id` = c.id
               where a.union_id = '{union_id}' """
         tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
-        tag_names = [tag["tag_name"] for tag in tag_data]
+        tag_names = [tag['tag_name'] for tag in tag_data]
         return tag_names
 
     def stop_user_daily_push(self, user_id: str) -> bool:
@@ -713,7 +639,7 @@ class MySQLUserRelationManager(UserRelationManager):
             if not union_id:
                 return False
             sql = f"UPDATE {self.user_table} SET group_msg_disabled = 1 WHERE union_id = %s"
-            rows = self.wecom_db.execute(sql, (union_id,))
+            rows = self.wecom_db.execute(sql, (union_id, ))
             if rows > 0:
                 return True
             else:
@@ -723,26 +649,23 @@ class MySQLUserRelationManager(UserRelationManager):
             return False
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     config = configs.get()
-    user_db_config = config["storage"]["user"]
-    staff_db_config = config["storage"]["staff"]
-    user_manager = MySQLUserManager(
-        user_db_config["mysql"], user_db_config["table"], staff_db_config["table"]
-    )
-    user_profile = user_manager.get_user_profile("7881301263964433")
+    user_db_config = config['storage']['user']
+    staff_db_config = config['storage']['staff']
+    user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+    user_profile = user_manager.get_user_profile('7881301263964433')
     print(user_profile)
 
-    wecom_db_config = config["storage"]["user_relation"]
+    wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
-        user_db_config["mysql"],
-        wecom_db_config["mysql"],
-        config["storage"]["staff"]["table"],
-        user_db_config["table"],
-        wecom_db_config["table"]["staff"],
-        wecom_db_config["table"]["relation"],
-        wecom_db_config["table"]["user"],
+        user_db_config['mysql'], wecom_db_config['mysql'],
+        config['storage']['staff']['table'],
+        user_db_config['table'],
+        wecom_db_config['table']['staff'],
+        wecom_db_config['table']['relation'],
+        wecom_db_config['table']['user']
     )
     # all_staff_users = user_relation_manager.list_staff_users()
-    user_tags = user_relation_manager.get_user_tags("7881302078008656")
+    user_tags = user_relation_manager.get_user_tags('7881302078008656')
     print(user_tags)

+ 11 - 32
pqai_agent_server/api_server.py

@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-
+import time
 import logging
 import werkzeug.exceptions
 from flask import Flask, request, jsonify
@@ -12,8 +12,7 @@ from pqai_agent import configs
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
-from pqai_agent_server.const import AgentApiConst
-from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
+from pqai_agent_server.utils import wrap_response
 from pqai_agent_server.utils import (
     run_extractor_prompt,
     run_chat_prompt,
@@ -22,7 +21,6 @@ from pqai_agent_server.utils import (
 
 app = Flask("agent_api_server")
 logger = logging_service.logger
-const = AgentApiConst()
 
 
 @app.route("/api/listStaffs", methods=["GET"])
@@ -169,9 +167,9 @@ def health_check():
 @app.route("/api/getStaffSessionSummary", methods=["GET"])
 def get_staff_session_summary():
     staff_id = request.args.get("staff_id")
-    status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
-    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
-    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
+    status = request.args.get("status", 1)
+    page_id = request.args.get("page_id", 1)
+    page_size = request.args.get("page_size", 10)
 
     # check params
     try:
@@ -197,11 +195,9 @@ def get_staff_session_list():
     if not staff_id:
         return wrap_response(404, msg="staff_id is required")
 
-    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
-    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
-    staff_session_list = app.user_manager.get_staff_session_list_v1(
-        staff_id, page_id, page_size
-    )
+    page_size = request.args.get("page_size", 10)
+    page_id = request.args.get("page_id", 1)
+    staff_session_list = app.user_manager.get_staff_session_list_v1(staff_id, page_id, page_size)
     if not staff_session_list:
         return wrap_response(404, msg="staff not found")
 
@@ -210,8 +206,8 @@ def get_staff_session_list():
 
 @app.route("/api/getStaffList", methods=["GET"])
 def get_staff_list():
-    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
-    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
+    page_size = request.args.get("page_size", 10)
+    page_id = request.args.get("page_id", 1)
     staff_list = app.user_manager.get_staff_list(page_id, page_size)
     if not staff_list:
         return wrap_response(404, msg="staff not found")
@@ -230,24 +226,7 @@ def get_conversation_list():
         return wrap_response(404, msg="staff_id and customer_id are required")
 
     page = request.args.get("page")
-    response = app.user_manager.get_conversation_list_v1(staff_id, customer_id, page, const.DEFAULT_CONVERSATION_SIZE)
-    return wrap_response(200, data=response)
-
-
-@app.route("/api/quitHumanInterventionStatus", methods=["GET"])
-def quit_human_interventions_status():
-    """
-    退出人工介入状态
-    :return:
-    """
-    staff_id = request.args.get("staff_id")
-    customer_id = request.args.get("customer_id")
-    # 测试环境: staff_id 强制等于1688854492669990
-    staff_id = 1688854492669990
-    if not customer_id or not staff_id:
-        return wrap_response(404, msg="user_id and staff_id are required")
-    response = quit_human_intervention_status(customer_id, staff_id)
-
+    response = app.user_manager.get_conversation_list_v1(staff_id, customer_id, page)
     return wrap_response(200, data=response)
 
 

+ 1 - 0
requirements.txt

@@ -54,6 +54,7 @@ pyapollos~=0.1.5
 Werkzeug~=3.1.3
 Flask~=3.1.0
 jsonschema~=4.23.0
+pqai_agent~=0.1.0
 numpy~=2.2.5
 pillow~=11.2.1
 json5~=0.12.0