Browse Source

修改创建任务为异步操作

xueyiming 5 ngày trước
mục cha
commit
d002b019e3

+ 2 - 2
pqai_agent/data_models/agent_test_task.py

@@ -12,8 +12,8 @@ class AgentTestTask(Base):
     module_id = Column(BigInteger, nullable=False, comment="model主键")
     create_user = Column(String(32), nullable=True, comment="创建用户")
     update_user = Column(String(32), nullable=True, comment="更新用户")
-    dataset_ids = Column(Text, nullable=False, comment="数据集ids")
-    status = Column(Integer, default=0, nullable=False, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
+    dataset_ids = Column(Text, nullable=True, comment="数据集ids")
+    status = Column(Integer, nullable=True, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
     create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
     update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
                          comment="更新时间")

+ 6 - 1
pqai_agent_server/const/status_enum.py

@@ -7,6 +7,9 @@ class TestTaskStatus(Enum):
     COMPLETED = 2
     CANCELLED = 3
     FAILED = 4
+    CREATING = 5
+    CREATED_FAIL = 6
+
 
     @property
     def description(self):
@@ -15,7 +18,9 @@ class TestTaskStatus(Enum):
             self.IN_PROGRESS: "进行中",
             self.COMPLETED: "已完成",
             self.CANCELLED: "已取消",
-            self.FAILED: "已失败"
+            self.FAILED: "已失败",
+            self.CREATING: "生成任务中",
+            self.CREATED_FAIL:"生成任务失败"
         }
         return descriptions.get(self)
 

+ 5 - 2
pqai_agent_server/dataset_service.py

@@ -43,7 +43,9 @@ class DatasetService:
         with self.session_maker() as session:
             return session.query(InternalConversationData).filter(
                 InternalConversationData.dataset_id == dataset_id).filter(
-                InternalConversationData.is_delete == 0).all()
+                InternalConversationData.is_delete == 0).order_by(
+                InternalConversationData.id.asc()
+            ).all()
 
     def get_conversation_data_by_id(self, conversation_data_id: int):
         with self.session_maker() as session:
@@ -125,7 +127,8 @@ class DatasetService:
                 data["datasetId"] = conversation_data.dataset_id
                 data["staff"] = self.get_staff_profile_data(conversation_data.staff_id).agent_profile
                 data["user"] = self.get_user_profile_data(conversation_data.user_id,
-                                                          conversation_data.version_date.replace("-", ""))['profile_data_v1']
+                                                          conversation_data.version_date.replace("-", ""))[
+                    'profile_data_v1']
                 data["conversation"] = self.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
                 data["content"] = conversation_data.content
                 data["sendTime"] = conversation_data.send_time

+ 97 - 24
pqai_agent_server/task_server.py

@@ -19,13 +19,14 @@ logger = logging_service.logger
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, session_maker, dataset_service, max_workers: int = 10):
+    def __init__(self, session_maker, dataset_service):
         self.session_maker = session_maker
         self.dataset_service = dataset_service
         self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
-        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
+        self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
+        self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
         self.task_futures = {}  # 任务ID -> Future
 
     def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
@@ -98,26 +99,71 @@ class TaskManager:
 
     def create_task(self, agent_id: int, module_id: int) -> Dict:
         """创建新任务"""
-        with (self.session_maker() as session):
-            with session.begin():
-                agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id)
-                session.add(agent_test_task)
-                session.flush()  # 强制SQL执行,但不提交事务
-                task_id = agent_test_task.id
-                agent_test_task_conversations = []
-                datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)
-                for datasets in datasets_list:
-                    conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(datasets.id)
-                    for conversation_data in conversation_datas:
-                        agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
-                                                                                  dataset_id=datasets.id,
-                                                                                  conversation_id=conversation_data.id)
-                        agent_test_task_conversations.append(agent_test_task_conversation)
-                session.add_all(agent_test_task_conversations)
-        # 异步执行任务
-        self._execute_task(task_id)
+        with self.session_maker() as session:
+            agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id,
+                                            status=TestTaskStatus.CREATING.value)
+            session.add(agent_test_task)
+            session.commit()  # 显式提交
+            task_id = agent_test_task.id
+        # 异步执行创建任务
+        self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
+                                         module_id)
         return self.get_task(task_id)
 
+    def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
+        """异步生成子任务"""
+        try:
+            # 获取数据集列表
+            datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)
+
+            # 批量处理数据集 - 减少数据库交互
+            batch_size = 100  # 每批处理100个子任务
+            agent_test_task_conversation_batch = []
+
+            for dataset in datasets_list:
+                # 获取对话数据列表
+                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset.id)
+
+                for conversation_data in conversation_datas:
+                    # 创建子任务对象
+                    agent_test_task_conversation = AgentTestTaskConversations(
+                        task_id=task_id,
+                        agent_id=agent_id,
+                        dataset_id=dataset.id,
+                        conversation_id=conversation_data.id,
+                        status=TestTaskConversationsStatus.PENDING.value
+                    )
+                    agent_test_task_conversation_batch.append(agent_test_task_conversation)
+
+                    # 批量提交
+                    if len(agent_test_task_conversation_batch) >= batch_size:
+                        self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
+                        agent_test_task_conversation_batch = []
+
+            # 提交剩余的子任务
+            if agent_test_task_conversation_batch:
+                self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
+
+            # 更新主任务状态为未开始
+            self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)
+
+            # 自动提交任务执行
+            self._execute_task(task_id)
+
+        except Exception as e:
+            logger.error(f"生成子任务失败: {str(e)}")
+            # 更新任务状态为失败
+            self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)
+
+    def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list):
+        """批量保存子任务到数据库"""
+        try:
+            with self.session_maker() as session:
+                with session.begin():
+                    session.add_all(agent_test_task_conversation_batch)
+        except Exception as e:
+            logger.error(e)
+
     def get_task(self, task_id: int):
         """获取任务信息"""
         with self.session_maker() as session:
@@ -128,11 +174,22 @@ class TaskManager:
         with self.session_maker() as session:
             return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
 
+    def get_creating_task(self):
+        """Return all tasks whose status is CREATING."""
+        with self.session_maker() as session:
+            return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.CREATING.value).all()
+
     def get_task_conversations(self, task_id: int):
         """获取任务的所有子任务"""
         with self.session_maker() as session:
             return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).all()
 
+    def del_task_conversations(self, task_id: int):
+        with self.session_maker() as session:
+            session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).delete()
+            # 提交事务生效
+            session.commit()
+
     def get_pending_task_conversations(self, task_id: int):
         """获取待处理的子任务"""
         with self.session_maker() as session:
@@ -206,7 +263,19 @@ class TaskManager:
 
     def recover_tasks(self):
         """服务启动时恢复未完成的任务"""
-        # 获取所有未完成的任务ID(根据实际状态定义查询)
+
+        creating = self.get_creating_task()
+        for task in creating:
+            task_id = task.id
+            agent_id = task.agent_id
+            module_id = task.module_id
+            self.del_task_conversations(task_id)
+            # 重新提交任务
+            # 异步执行创建任务
+            self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
+                                             module_id)
+
+        # 获取所有进行中的任务ID(根据实际状态定义查询)
         in_progress_tasks = self.get_in_progress_task()
 
         for task in in_progress_tasks:
@@ -256,9 +325,13 @@ class TaskManager:
                 try:
                     conversation_data = self.dataset_service.get_conversation_data_by_id(
                         task_conversation.conversation_id)
-                    user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id)['profile_data_v1']
-                    staff_profile_data = self.dataset_service.get_staff_profile_data(conversation_data.staff_id).agent_profile
-                    conversations = self.dataset_service.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
+                    user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
+                                                                                   conversation_data.version_date.replace(
+                                                                                       "-", ""))['profile_data_v1']
+                    staff_profile_data = self.dataset_service.get_staff_profile_data(
+                        conversation_data.staff_id).agent_profile
+                    conversations = self.dataset_service.get_conversation_list_by_ids(
+                        json.loads(conversation_data.conversation))
 
                     # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
                     # TODO 后续改成实际任务执行