소스 검색

增加数据集相关操作,修改单元测试任务

xueyiming 4 일 전
부모
커밋
d26aa2c638

+ 2 - 2
pqai_agent/data_models/agent_test_task.py

@@ -14,6 +14,6 @@ class AgentTestTask(Base):
     update_user = Column(String(32), nullable=True, comment="更新用户")
     dataset_ids = Column(Text, nullable=False, comment="数据集ids")
     status = Column(Integer, default=0, nullable=False, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
-    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")

+ 2 - 2
pqai_agent/data_models/agent_test_task_conversations.py

@@ -17,6 +17,6 @@ class AgentTestTaskConversations(Base):
     score = Column(Text, nullable=False, comment="得分")
     status = Column(Integer, default=0, nullable=False,
                     comment="状态(0:待执行, 1:执行中, 2:执行成功, 3:执行失败, 4:已取消)")
-    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")

+ 2 - 2
pqai_agent/data_models/dataset_model.py

@@ -12,6 +12,6 @@ class DatasetModule(Base):
     module_id = Column(BigInteger, nullable=False, comment="模型id")
     is_default = Column(Integer, nullable=False, default=0, comment="是否为该模块的默认数据集")
     is_delete = Column(Integer, nullable=False, default=0, comment="是否删除 1-删除 0-未删除")
-    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")

+ 2 - 2
pqai_agent/data_models/datasets.py

@@ -12,7 +12,7 @@ class Datasets(Base):
     type = Column(Integer, default=0, nullable=False, comment="数据集类型 0-内部 1-外部")
     description = Column(String(256), nullable=True, comment="数据集描述")
     is_delete = Column(Integer, nullable=False, default=False, comment="是否删除 1-删除 0-未删除")
-    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")
 

+ 2 - 2
pqai_agent/data_models/internal_conversation_data.py

@@ -18,6 +18,6 @@ class InternalConversationData(Base):
     send_type = Column(Integer, nullable=False, comment="回复类型 0: reply 1: push")
     user_active_rate = Column(Float, nullable=False, comment="用户活跃度")
     is_delete = Column(Integer, nullable=False, default=False, comment="是否删除 1-删除 0-未删除")
-    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
-    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+    create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")

+ 43 - 3
pqai_agent_server/api_server.py

@@ -22,7 +22,7 @@ from pqai_agent.utils.db_utils import create_sql_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.const.status_enum import TestTaskStatus
-from pqai_agent_server.dataset_server import DatasetServer
+from pqai_agent_server.dataset_service import DatasetService
 from pqai_agent_server.models import MySQLSessionManager
 from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
@@ -621,6 +621,43 @@ def resume_test_task():
     return wrap_response(200)
 
 
+@app.route("/api/getDatasetList", methods=["GET"])
+def get_dataset_list():
+    """
+    Return one page of datasets.
+
+    Query parameters:
+        pageNum: page index, defaults to const.DEFAULT_PAGE_ID
+        pageSize: rows per page, defaults to const.DEFAULT_PAGE_SIZE
+
+    :return: wrapped 200 response carrying the service result, or a wrapped
+             404 response when a paging parameter is not a positive integer
+    """
+    raw_page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
+    raw_page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
+    try:
+        page_num = int(raw_page_num)
+        page_size = int(raw_page_size)
+    except (TypeError, ValueError) as e:
+        return wrap_response(404, msg="Invalid parameter: {}".format(e))
+    # Non-positive values would otherwise reach the paging arithmetic and yield
+    # a negative offset or a division by zero.
+    if page_num < 1 or page_size < 1:
+        return wrap_response(404, msg="Invalid parameter: pageNum and pageSize must be positive")
+    response = app.dataset_service.get_dataset_list(page_num, page_size)
+    return wrap_response(200, data=response)
+
+
+@app.route("/api/getConversationDataList", methods=["GET"])
+def get_conversation_data_list():
+    """
+    Return one page of conversation data belonging to a dataset.
+
+    Query parameters:
+        datasetId: required id of the dataset
+        pageNum: page index, defaults to const.DEFAULT_PAGE_ID
+        pageSize: rows per page, defaults to const.DEFAULT_PAGE_SIZE
+
+    :return: wrapped 200 response carrying the service result, or a wrapped
+             404 response when a parameter is missing or not a valid integer
+    """
+    dataset_id = request.args.get("datasetId", None)
+    if not dataset_id:
+        return wrap_response(404, msg='dataset_id is required')
+    raw_page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
+    raw_page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
+    try:
+        # datasetId is converted inside the try block so a non-numeric value
+        # produces the same error response as a bad paging parameter.
+        dataset_id = int(dataset_id)
+        page_num = int(raw_page_num)
+        page_size = int(raw_page_size)
+    except (TypeError, ValueError) as e:
+        return wrap_response(404, msg="Invalid parameter: {}".format(e))
+    # Non-positive values would otherwise reach the paging arithmetic and yield
+    # a negative offset or a division by zero.
+    if page_num < 1 or page_size < 1:
+        return wrap_response(404, msg="Invalid parameter: pageNum and pageSize must be positive")
+    response = app.dataset_service.get_conversation_data_list(dataset_id, page_num, page_size)
+    return wrap_response(200, data=response)
+
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)
@@ -661,9 +698,12 @@ if __name__ == '__main__':
     agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
-    dataset_server = DatasetServer(session_maker=sessionmaker(bind=agent_db_engine))
-    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_server=dataset_server)
+    dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))
+    app.dataset_service = dataset_service
+
+    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
     app.task_manager = task_manager
+    task_manager.recover_tasks()
 
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(

+ 11 - 10
pqai_agent_server/const/status_enum.py

@@ -6,6 +6,7 @@ class TestTaskStatus(Enum):
     IN_PROGRESS = 1
     COMPLETED = 2
     CANCELLED = 3
+    FAILED = 4
 
     @property
     def description(self):
@@ -13,10 +14,19 @@ class TestTaskStatus(Enum):
             self.NOT_STARTED: "未开始",
             self.IN_PROGRESS: "进行中",
             self.COMPLETED: "已完成",
-            self.CANCELLED: "已取消"
+            self.CANCELLED: "已取消",
+            self.FAILED: "已失败"
         }
         return descriptions.get(self)
 
+# Usage example
+def get_test_task_status_desc(status_code):
+    """Return the description of a test-task status code, or an "unknown status" message if the code is not a TestTaskStatus value."""
+    try:
+        return TestTaskStatus(status_code).description
+    except ValueError:
+        return f"未知状态: {status_code}"
+
 class TestTaskConversationsStatus(Enum):
     """任务状态枚举类"""
     PENDING = 0  # 待执行
@@ -36,15 +46,6 @@ class TestTaskConversationsStatus(Enum):
         }
         return descriptions.get(self)
 
-
-# 使用示例
-def get_test_task_status_desc(status_code):
-    try:
-        status = TestTaskStatus(status_code)
-        return status.description
-    except ValueError:
-        return f"未知状态: {status_code}"
-
 # 使用示例
 def get_test_task_conversations_status_desc(status_code):
     try:

+ 22 - 0
pqai_agent_server/const/type_enum.py

@@ -0,0 +1,22 @@
+from enum import Enum
+
+
+class DatasetType(Enum):
+    """Dataset type codes: INTERNAL is 0, EXTERNAL is 1."""
+
+    INTERNAL = 0
+    EXTERNAL = 1
+
+    @property
+    def description(self):
+        """Return the display label of this member."""
+        if self is DatasetType.INTERNAL:
+            return "内部"
+        if self is DatasetType.EXTERNAL:
+            return "外部"
+        return None
+
+# Usage example
+def get_dataset_type_desc(type_code):
+    """Return the description of a dataset type code, or an "unknown type" message if the code is not a DatasetType value."""
+    try:
+        dataset_type = DatasetType(type_code)
+    except ValueError:
+        return f"未知类型: {type_code}"
+    return dataset_type.description

+ 0 - 55
pqai_agent_server/dataset_server.py

@@ -1,55 +0,0 @@
-from typing import List
-
-from pqai_agent.data_models.dataset_model import DatasetModule
-from pqai_agent.data_models.internal_conversation_data import InternalConversationData
-from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
-from pqai_agent.data_models.qywx_employee import QywxEmployee
-from pqai_agent_server.utils.odps_utils import ODPSUtils
-
-
-class DatasetServer:
-    def __init__(self, session_maker):
-        self.session_maker = session_maker
-        odps_utils = ODPSUtils()
-        self.odps_utils = odps_utils
-
-    def get_user_profile_data(self, third_party_user_id: str, date_version: str):
-        sql = f"""
-           SELECT * FROM third_party_user_date_version
-           WHERE dt between '20250612' and {date_version}  -- 添加分区条件
-           and third_party_user_id = {third_party_user_id}
-           and profile_data_v1 is not null 
-           order by dt desc 
-           limit 1
-           """
-        result_df = self.odps_utils.execute_sql(sql)
-
-        if not result_df.empty:
-            return result_df.iloc[0].to_dict()  # 获取第一行
-        return None
-
-    def get_dataset_list_by_module(self, module_id: int):
-        with self.session_maker() as session:
-            return session.query(DatasetModule).filter(DatasetModule.module_id == module_id).filter(
-                DatasetModule.is_delete == 0).all()
-
-    def get_conversation_data_list_by_dataset(self, dataset_id: int):
-        with self.session_maker() as session:
-            return session.query(InternalConversationData).filter(
-                InternalConversationData.dataset_id == dataset_id).filter(
-                DatasetModule.is_delete == 0).all()
-
-    def get_conversation_data_by_id(self, conversation_data_id: int):
-        with self.session_maker() as session:
-            return session.query(InternalConversationData).filter(
-                InternalConversationData.id == conversation_data_id).one()
-
-    def get_staff_profile_data(self, third_party_user_id: str):
-        with self.session_maker() as session:
-            return session.query(QywxEmployee).filter(
-                QywxEmployee.third_party_user_id == third_party_user_id).one()
-
-    def get_conversation_list_by_ids(self, conversation_ids: List[int]):
-        with self.session_maker() as session:
-            return session.query(QywxChatHistory).filter(QywxChatHistory.id in conversation_ids).all()
-

+ 143 - 0
pqai_agent_server/dataset_service.py

@@ -0,0 +1,143 @@
+import json
+from cgitb import reset
+from typing import List
+
+from sqlalchemy import func
+
+from pqai_agent.data_models.dataset_model import DatasetModule
+from pqai_agent.data_models.datasets import Datasets
+from pqai_agent.data_models.internal_conversation_data import InternalConversationData
+from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
+from pqai_agent.data_models.qywx_employee import QywxEmployee
+from pqai_agent_server.const.type_enum import get_dataset_type_desc
+from pqai_agent_server.utils.odps_utils import ODPSUtils
+
+
+class DatasetService:
+    def __init__(self, session_maker):
+        """Keep the session factory and create the ODPS helper used for profile queries."""
+        self.session_maker = session_maker
+        self.odps_utils = ODPSUtils()
+
+    def get_user_profile_data(self, third_party_user_id: str, date_version: str):
+        """
+        Fetch the most recent profile row of a user through the ODPS helper.
+
+        :param third_party_user_id: value compared with the third_party_user_id column
+        :param date_version: upper bound of the dt partition range
+        :return: the first result row as a dict, or None when the result is empty
+        """
+        # NOTE(review): both arguments are interpolated into the SQL text without
+        # quoting or escaping - confirm callers only pass trusted, well-formed values.
+        query = f"""
+           SELECT * FROM third_party_user_date_version
+           WHERE dt between '20250612' and {date_version}  -- 添加分区条件
+           and third_party_user_id = {third_party_user_id}
+           and profile_data_v1 is not null 
+           order by dt desc 
+           limit 1
+           """
+        rows = self.odps_utils.execute_sql(query)
+        if rows.empty:
+            return None
+        # Only the first row is used
+        return rows.iloc[0].to_dict()
+
+    def get_dataset_list_by_module(self, module_id: int):
+        """Return every non-deleted DatasetModule row whose module_id matches."""
+        with self.session_maker() as session:
+            query = session.query(DatasetModule).filter(
+                DatasetModule.module_id == module_id,
+                DatasetModule.is_delete == 0,
+            )
+            return query.all()
+
+    def get_conversation_data_list_by_dataset(self, dataset_id: int):
+        """Return every non-deleted InternalConversationData row of the given dataset."""
+        with self.session_maker() as session:
+            query = session.query(InternalConversationData).filter(
+                InternalConversationData.dataset_id == dataset_id,
+                InternalConversationData.is_delete == 0,
+            )
+            return query.all()
+
+    def get_conversation_data_by_id(self, conversation_data_id: int):
+        """Return the InternalConversationData row with the given id; fetched with .one(), so exactly one match is expected."""
+        with self.session_maker() as session:
+            query = session.query(InternalConversationData)
+            query = query.filter(InternalConversationData.id == conversation_data_id)
+            return query.one()
+
+    def get_staff_profile_data(self, third_party_user_id: str):
+        """Return the QywxEmployee row with the given third_party_user_id; fetched with .one(), so exactly one match is expected."""
+        with self.session_maker() as session:
+            query = session.query(QywxEmployee)
+            query = query.filter(QywxEmployee.third_party_user_id == third_party_user_id)
+            return query.one()
+
+    def get_conversation_list_by_ids(self, conversation_ids: List[int]):
+        """
+        Load chat-history rows by id and convert each one to a plain dict.
+
+        :param conversation_ids: ids of the QywxChatHistory rows to load
+        :return: list of dicts with id, sender, receiver, roomid, sendtime
+                 (the stored value divided by 1000), msg_type and content
+        """
+        with self.session_maker() as session:
+            rows = session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).all()
+            return [
+                {
+                    "id": row.id,
+                    "sender": row.sender,
+                    "receiver": row.receiver,
+                    "roomid": row.roomid,
+                    "sendtime": row.sendtime / 1000,
+                    "msg_type": row.msg_type,
+                    "content": row.content,
+                }
+                for row in rows
+            ]
+
+    def get_dataset_list(self, page_num: int, page_size: int):
+        """
+        Return one page of non-deleted datasets together with paging metadata.
+
+        :param page_num: 1-based page index; values below 1 are treated as 1
+        :param page_size: rows per page; values below 1 are treated as 1
+        :return: dict with currentPage, pageSize, totalSize (number of pages,
+                 never less than 1), total (matching row count) and list
+                 (one dict per dataset on the page)
+        """
+        # Clamp the inputs so the offset cannot go negative and the page-count
+        # division cannot divide by zero.
+        page_num = max(page_num, 1)
+        page_size = max(page_size, 1)
+        time_format = "%Y-%m-%d %H:%M:%S"
+        with self.session_maker() as session:
+            offset = (page_num - 1) * page_size
+            # Without an explicit ORDER BY the row order is not guaranteed, so
+            # consecutive pages could repeat or skip rows.
+            result = (session.query(Datasets)
+                      .filter(Datasets.is_delete == 0)
+                      .order_by(Datasets.id)
+                      .limit(page_size).offset(offset).all())
+            total = session.query(func.count(Datasets.id)).filter(Datasets.is_delete == 0).scalar()
+
+            # Ceiling division; an empty result still reports a single page.
+            total_page = max((total + page_size - 1) // page_size, 1)
+            response_data = [
+                {
+                    "id": dataset.id,
+                    "name": dataset.name,
+                    "type": get_dataset_type_desc(dataset.type),
+                    "description": dataset.description,
+                    # Timestamps may be missing on rows written before the
+                    # columns became NOT NULL.
+                    "createTime": dataset.create_time.strftime(time_format) if dataset.create_time else None,
+                    "updateTime": dataset.update_time.strftime(time_format) if dataset.update_time else None,
+                }
+                for dataset in result
+            ]
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
+            }
+
+    def get_conversation_data_list(self, dataset_id: int, page_num: int, page_size: int):
+        """
+        Return one page of non-deleted conversation data of a dataset.
+
+        Each row is enriched with the staff profile, the user profile and the
+        referenced chat messages.
+
+        :param dataset_id: id of the dataset whose rows are listed
+        :param page_num: 1-based page index; values below 1 are treated as 1
+        :param page_size: rows per page; values below 1 are treated as 1
+        :return: dict with currentPage, pageSize, totalSize (number of pages,
+                 never less than 1), total (matching row count) and list
+                 (one dict per row on the page)
+        """
+        # Clamp the inputs so the offset cannot go negative and the page-count
+        # division cannot divide by zero.
+        page_num = max(page_num, 1)
+        page_size = max(page_size, 1)
+        time_format = "%Y-%m-%d %H:%M:%S"
+        with self.session_maker() as session:
+            offset = (page_num - 1) * page_size
+            # The page query and the count share the same predicates, so
+            # "total" counts rows of this dataset only.
+            criteria = (
+                InternalConversationData.dataset_id == dataset_id,
+                InternalConversationData.is_delete == 0,
+            )
+            # Without an explicit ORDER BY the row order is not guaranteed, so
+            # consecutive pages could repeat or skip rows.
+            result = (session.query(InternalConversationData)
+                      .filter(*criteria)
+                      .order_by(InternalConversationData.id)
+                      .limit(page_size).offset(offset).all())
+            total = session.query(func.count(InternalConversationData.id)).filter(*criteria).scalar()
+
+            # Ceiling division; an empty result still reports a single page.
+            total_page = max((total + page_size - 1) // page_size, 1)
+            response_data = []
+            for conversation_data in result:
+                staff_profile = self.get_staff_profile_data(conversation_data.staff_id).agent_profile
+                # get_user_profile_data returns None when no profile row matches.
+                user_profile = self.get_user_profile_data(conversation_data.user_id,
+                                                          conversation_data.version_date.replace("-", ""))
+                create_time = conversation_data.create_time
+                update_time = conversation_data.update_time
+                response_data.append({
+                    "id": conversation_data.id,
+                    "datasetId": conversation_data.dataset_id,
+                    "staff": staff_profile,
+                    "user": user_profile['profile_data_v1'] if user_profile else None,
+                    "conversation": self.get_conversation_list_by_ids(json.loads(conversation_data.conversation)),
+                    "content": conversation_data.content,
+                    "sendTime": conversation_data.send_time,
+                    "sendType": conversation_data.send_type,
+                    "userActiveRate": conversation_data.user_active_rate,
+                    # These two keys were previously written as annotations
+                    # (`data[...]: value`), which assign nothing.
+                    "createTime": create_time.strftime(time_format) if create_time else None,
+                    "updateTime": update_time.strftime(time_format) if update_time else None,
+                })
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
+            }

+ 39 - 17
pqai_agent_server/task_server.py

@@ -1,10 +1,10 @@
-import threading
+import json
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
 from typing import Dict
 
-from pyarrow.dataset import dataset
 from sqlalchemy import func
 
 from pqai_agent import logging_service
@@ -19,9 +19,9 @@ logger = logging_service.logger
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, session_maker, dataset_server, max_workers: int = 10):
+    def __init__(self, session_maker, dataset_service, max_workers: int = 10):
         self.session_maker = session_maker
-        self.dataset_server = dataset_server
+        self.dataset_service = dataset_service
         self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
@@ -105,9 +105,9 @@ class TaskManager:
                 session.flush()  # 强制SQL执行,但不提交事务
                 task_id = agent_test_task.id
                 agent_test_task_conversations = []
-                datasets_list = self.dataset_server.get_dataset_list_by_module(module_id)
+                datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)
                 for datasets in datasets_list:
-                    conversation_datas = self.dataset_server.get_conversation_data_list_by_dataset(datasets.id)
+                    conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(datasets.id)
                     for conversation_data in conversation_datas:
                         agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
                                                                                   dataset_id=datasets.id,
@@ -123,6 +123,11 @@ class TaskManager:
         with self.session_maker() as session:
             return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
 
+    def get_in_progress_task(self):
+        """Return every AgentTestTask whose status is IN_PROGRESS."""
+        in_progress = TestTaskStatus.IN_PROGRESS.value
+        with self.session_maker() as session:
+            query = session.query(AgentTestTask).filter(AgentTestTask.status == in_progress)
+            return query.all()
+
     def get_task_conversations(self, task_id: int):
         """获取任务的所有子任务"""
         with self.session_maker() as session:
@@ -133,26 +138,32 @@ class TaskManager:
         with self.session_maker() as session:
             return session.query(AgentTestTaskConversations).filter(
                 AgentTestTaskConversations.task_id == task_id).filter(
-                AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()
+                AgentTestTaskConversations.status.in_([
+                    TestTaskConversationsStatus.PENDING.value,
+                    TestTaskConversationsStatus.RUNNING.value
+                ])).all()
 
     def update_task_status(self, task_id: int, status: int):
         """更新任务状态"""
         with self.session_maker() as session:
-            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
+            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
+                {"status": status, "update_time": datetime.now()})
             session.commit()
 
     def update_task_conversations_status(self, task_conversations_id: int, status: int):
         """更新子任务状态"""
         with self.session_maker() as session:
             session.query(AgentTestTaskConversations).filter(
-                AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
+                AgentTestTaskConversations.id == task_conversations_id).update(
+                {"status": status, "update_time": datetime.now()})
             session.commit()
 
     def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
         """更新子任务结果"""
         with self.session_maker() as session:
             session.query(AgentTestTaskConversations).filter(
-                AgentTestTaskConversations.id == task_conversations_id).update({"status": status, "score": score})
+                AgentTestTaskConversations.id == task_conversations_id).update(
+                {"status": status, "score": score, "update_time": datetime.now()})
             session.commit()
 
     def cancel_task(self, task_id: int):
@@ -193,6 +204,16 @@ class TaskManager:
         logger.info(f"Resumed task {task_id}")
         return True
 
+    def recover_tasks(self):
+        """Resubmit every task that is still marked in progress; meant to be called at service start-up."""
+        for task in self.get_in_progress_task():
+            # Hand the task back to _execute_task for resubmission
+            self._execute_task(task.id)
+
+
     def _execute_task(self, task_id: int):
         """提交任务到线程池执行"""
         # 确保任务状态一致性
@@ -229,14 +250,15 @@ class TaskManager:
                     break
 
                 # 更新子任务状态为运行中
-                self.update_task_conversations_status(task_conversation.id,
-                                                      TestTaskConversationsStatus.RUNNING.value)
+                if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
+                    self.update_task_conversations_status(task_conversation.id,
+                                                          TestTaskConversationsStatus.RUNNING.value)
                 try:
-                    conversation_data = self.dataset_server.get_conversation_data_by_id(
+                    conversation_data = self.dataset_service.get_conversation_data_by_id(
                         task_conversation.conversation_id)
-                    user_profile_data = self.dataset_server.get_user_profile_data(conversation_data.user_id)
-                    staff_profile_data = self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
-
+                    user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id)['profile_data_v1']
+                    staff_profile_data = self.dataset_service.get_staff_profile_data(conversation_data.staff_id).agent_profile
+                    conversations = self.dataset_service.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
 
                     # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
                     # TODO 后续改成实际任务执行
@@ -270,7 +292,7 @@ class TaskManager:
                     if all_completed else TestTaskStatus.CANCELLED.value)
         except Exception as e:
             logger.error(f"Error executing task {task_id}: {str(e)}")
-            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+            self.update_task_status(task_id, TestTaskStatus.FAILED.value)
         finally:
             # 清理资源
             with self.task_locks[task_id]: