Explorar o código

增加并发操作

xueyiming hai 3 días
pai
achega
190ab150e5
Modificáronse 1 ficheiros con 256 adicións e 92 borrados
  1. 256 92
      pqai_agent_server/task_server.py

+ 256 - 92
pqai_agent_server/task_server.py

@@ -1,5 +1,6 @@
 import json
 import threading
+import concurrent.futures
 import time
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
@@ -15,7 +16,7 @@ from pqai_agent.data_models.agent_test_task import AgentTestTask
 from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
 from pqai_agent.data_models.service_module import ServiceModule
 from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
-
+from concurrent.futures import ThreadPoolExecutor
 logger = logging_service.logger
 
 
@@ -324,107 +325,270 @@ class TaskManager:
             self.running_tasks.add(task_id)
 
     def _process_task(self, task_id: int):
-        """处理任务的所有子任务"""
+        """处理任务的所有子任务(并发执行)"""
         try:
-            # 更新任务状态为运行中
             self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
-
-            # 获取所有待处理的子任务
             task_conversations = self.get_pending_task_conversations(task_id)
 
+            if not task_conversations:
+                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+                return
+
             agent_configuration = self.get_agent_configuration_by_task_id(task_id)
             query_prompt_template = agent_configuration.task_prompt
-            agent = MultiModalChatAgent(model=agent_configuration.execution_model,
-                                        system_prompt=agent_configuration.system_prompt,
-                                        tools=json.loads(agent_configuration.tools))
-            # 执行每个子任务
-            for task_conversation in task_conversations:
-                # 检查任务是否被取消
-                if self.task_events[task_id].is_set():
-                    break
-
-                # 更新子任务状态为运行中
-                if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
-                    self.update_task_conversations_status(task_conversation.id,
-                                                          TestTaskConversationsStatus.RUNNING.value)
-                try:
-                    conversation_data = self.dataset_service.get_conversation_data_by_id(
-                        task_conversation.conversation_id)
-                    user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
-                                                                                   conversation_data.version_date.replace(
-                                                                                       "-", ""))
-                    user_profile = json.loads(user_profile_data['profile_data_v1'])
-                    avatar = user_profile_data['iconurl']
-                    staff_profile_data = self.dataset_service.get_staff_profile_data(
-                        conversation_data.staff_id).agent_profile
-                    conversations = self.dataset_service.get_chat_conversation_list_by_ids(
-                        json.loads(conversation_data.conversation), conversation_data.staff_id)
-                    conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
-
-                    last_timestamp = int(conversations[-1]["timestamp"])
-                    push_time = int(last_timestamp / 1000) + 24 * 3600
-                    push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
-                    push_message = agent._generate_message(
-                        context={
-                            "formatted_staff_profile": staff_profile_data,
-                            "nickname": user_profile['nickname'],
-                            "name": user_profile['name'],
-                            "avatar": avatar,
-                            "preferred_nickname": user_profile['preferred_nickname'],
-                            "gender": user_profile['gender'],
-                            "age": user_profile['age'],
-                            "region": user_profile['region'],
-                            "health_conditions": user_profile['health_conditions'],
-                            "medications": user_profile['medications'],
-                            "interests": user_profile['interests'],
-                            "current_datetime": push_dt
-                        },
-                        dialogue_history=conversations,
-                        query_prompt_template=query_prompt_template
+
+            # 使用线程池执行子任务
+            with ThreadPoolExecutor(max_workers=20) as executor:  # 可根据需要调整并发数
+                futures = {}
+                for task_conversation in task_conversations:
+                    if self.task_events[task_id].is_set():
+                        break  # 检查任务取消事件
+
+                    # 提交子任务到线程池
+                    future = executor.submit(
+                        self._process_single_conversation,
+                        task_id,
+                        task_conversation,
+                        query_prompt_template,
+                        agent_configuration
                     )
-                    # TODO 获取打分
-                    score = '{"score":0.05}'
-                    # 更新子任务状态为已完成
-                    self.update_task_conversations_res(task_conversation.id,
-                                                       TestTaskConversationsStatus.SUCCESS.value,
-                                                       json.dumps(conversations, ensure_ascii=False),
-                                                       json.dumps(push_message, ensure_ascii=False),
-                                                       score)
-                except Exception as e:
-                    logger.error(f"Error executing task {task_id}: {str(e)}")
-                    self.update_task_conversations_status(task_conversation.id,
-                                                          TestTaskConversationsStatus.FAILED.value)
-
-            # 检查任务是否完成
-            task_conversations = self.get_task_conversations(task_id)
-            all_completed = all(task_conversation.status in
-                                (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
-                                for task_conversation in task_conversations)
-            any_pending = any(task_conversation.status in
-                              (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
-                              for task_conversation in task_conversations)
-
-            if all_completed:
-                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
-                logger.info(f"Task {task_id} completed")
-            elif not any_pending:
-                # 没有待处理子任务但未全部完成(可能是取消了)
-                current_status = self.get_task(task_id).status
-                if current_status != TestTaskStatus.CANCELLED.value:
-                    self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
-                    if all_completed else TestTaskStatus.CANCELLED.value)
+                    futures[future] = task_conversation.id
+
+                # 等待所有子任务完成或取消
+                for future in concurrent.futures.as_completed(futures):
+                    conv_id = futures[future]
+                    try:
+                        # 设置单个任务超时时间(秒),可根据任务复杂度调整
+                        future.result()  # 获取结果(如有异常会在此抛出)
+                    except Exception as e:
+                        logger.error(f"Subtask {conv_id} failed: {str(e)}")
+                        self.update_task_conversations_status(
+                            conv_id,
+                            TestTaskConversationsStatus.FAILED.value
+                        )
+
+            # 检查最终任务状态
+            self._update_final_task_status(task_id)
+
         except Exception as e:
-            logger.error(f"Error executing task {task_id}: {str(e)}")
+            logger.error(f"Error processing task {task_id}: {str(e)}")
             self.update_task_status(task_id, TestTaskStatus.FAILED.value)
         finally:
-            # 清理资源
-            with self.task_locks[task_id]:
-                if task_id in self.running_tasks:
-                    self.running_tasks.remove(task_id)
-                if task_id in self.task_events:
-                    del self.task_events[task_id]
-                if task_id in self.task_futures:
-                    del self.task_futures[task_id]
+            self._cleanup_task_resources(task_id)
+
+    def _process_single_conversation(self, task_id, task_conversation, query_prompt_template, agent_configuration):
+        """处理单个对话子任务(线程安全)"""
+        # 检查任务是否被取消
+        if self.task_events[task_id].is_set():
+            return
+
+        # 更新子任务状态
+        if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
+            self.update_task_conversations_status(
+                task_conversation.id,
+                TestTaskConversationsStatus.RUNNING.value
+            )
+
+        try:
+            # 创建独立的agent实例(确保线程安全)
+            agent = MultiModalChatAgent(
+                model=agent_configuration.execution_model,
+                system_prompt=agent_configuration.system_prompt,
+                tools=json.loads(agent_configuration.tools)
+            )
+
+            # 获取对话数据(与原始代码相同)
+            conversation_data = self.dataset_service.get_conversation_data_by_id(
+                task_conversation.conversation_id)
+            user_profile_data = self.dataset_service.get_user_profile_data(
+                conversation_data.user_id,
+                conversation_data.version_date.replace("-", ""))
+            user_profile = json.loads(user_profile_data['profile_data_v1'])
+            avatar = user_profile_data['iconurl']
+            staff_profile_data = self.dataset_service.get_staff_profile_data(
+                conversation_data.staff_id).agent_profile
+            conversations = self.dataset_service.get_chat_conversation_list_by_ids(
+                json.loads(conversation_data.conversation),
+                conversation_data.staff_id
+            )
+            conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
+
+            # 生成推送消息(与原始代码相同)
+            last_timestamp = int(conversations[-1]["timestamp"])
+            push_time = int(last_timestamp / 1000) + 24 * 3600
+            push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
+            push_message = agent._generate_message(
+                context={
+                    "formatted_staff_profile": staff_profile_data,
+                    "nickname": user_profile['nickname'],
+                    "name": user_profile['name'],
+                    "avatar": avatar,
+                    "preferred_nickname": user_profile['preferred_nickname'],
+                    "gender": user_profile['gender'],
+                    "age": user_profile['age'],
+                    "region": user_profile['region'],
+                    "health_conditions": user_profile['health_conditions'],
+                    "medications": user_profile['medications'],
+                    "interests": user_profile['interests'],
+                    "current_datetime": push_dt
+                },
+                dialogue_history=conversations,
+                query_prompt_template=query_prompt_template
+            )
+
+            # 获取打分(TODO: 实际实现)
+            score = '{"score":0.05}'
+
+            # 更新子任务结果
+            self.update_task_conversations_res(
+                task_conversation.id,
+                TestTaskConversationsStatus.SUCCESS.value,
+                json.dumps(conversations, ensure_ascii=False),
+                json.dumps(push_message, ensure_ascii=False),
+                score
+            )
+
+        except Exception as e:
+            logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
+            self.update_task_conversations_status(
+                task_conversation.id,
+                TestTaskConversationsStatus.FAILED.value
+            )
+            raise  # 重新抛出异常以便主线程捕获
+
+    def _update_final_task_status(self, task_id):
+        """更新任务的最终状态"""
+        task_conversations = self.get_task_conversations(task_id)
+        all_completed = all(
+            conv.status in (TestTaskConversationsStatus.SUCCESS.value,
+                            TestTaskConversationsStatus.FAILED.value)
+            for conv in task_conversations
+        )
+
+        if all_completed:
+            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+            logger.info(f"Task {task_id} completed")
+        elif not any(
+                conv.status in (TestTaskConversationsStatus.PENDING.value,
+                                TestTaskConversationsStatus.RUNNING.value)
+                for conv in task_conversations
+        ):
+            current_status = self.get_task(task_id).status
+            if current_status != TestTaskStatus.CANCELLED.value:
+                new_status = TestTaskStatus.COMPLETED.value if all_completed else TestTaskStatus.CANCELLED.value
+                self.update_task_status(task_id, new_status)
+
+    def _cleanup_task_resources(self, task_id):
+        """清理任务资源(线程安全)"""
+        with self.task_locks[task_id]:
+            if task_id in self.running_tasks:
+                self.running_tasks.remove(task_id)
+            if task_id in self.task_events:
+                del self.task_events[task_id]
+            if task_id in self.task_futures:
+                del self.task_futures[task_id]
+    # def _process_task(self, task_id: int):
+    #     """处理任务的所有子任务"""
+    #     try:
+    #         # 更新任务状态为运行中
+    #         self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
+    #
+    #         # 获取所有待处理的子任务
+    #         task_conversations = self.get_pending_task_conversations(task_id)
+    #
+    #         agent_configuration = self.get_agent_configuration_by_task_id(task_id)
+    #         query_prompt_template = agent_configuration.task_prompt
+    #         agent = MultiModalChatAgent(model=agent_configuration.execution_model,
+    #                                     system_prompt=agent_configuration.system_prompt,
+    #                                     tools=json.loads(agent_configuration.tools))
+    #         # 执行每个子任务
+    #         for task_conversation in task_conversations:
+    #             # 检查任务是否被取消
+    #             if self.task_events[task_id].is_set():
+    #                 break
+    #
+    #             # 更新子任务状态为运行中
+    #             if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
+    #                 self.update_task_conversations_status(task_conversation.id,
+    #                                                       TestTaskConversationsStatus.RUNNING.value)
+    #             try:
+    #                 conversation_data = self.dataset_service.get_conversation_data_by_id(
+    #                     task_conversation.conversation_id)
+    #                 user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
+    #                                                                                conversation_data.version_date.replace(
+    #                                                                                    "-", ""))
+    #                 user_profile = json.loads(user_profile_data['profile_data_v1'])
+    #                 avatar = user_profile_data['iconurl']
+    #                 staff_profile_data = self.dataset_service.get_staff_profile_data(
+    #                     conversation_data.staff_id).agent_profile
+    #                 conversations = self.dataset_service.get_chat_conversation_list_by_ids(
+    #                     json.loads(conversation_data.conversation), conversation_data.staff_id)
+    #                 conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
+    #
+    #                 last_timestamp = int(conversations[-1]["timestamp"])
+    #                 push_time = int(last_timestamp / 1000) + 24 * 3600
+    #                 push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
+    #                 push_message = agent._generate_message(
+    #                     context={
+    #                         "formatted_staff_profile": staff_profile_data,
+    #                         "nickname": user_profile['nickname'],
+    #                         "name": user_profile['name'],
+    #                         "avatar": avatar,
+    #                         "preferred_nickname": user_profile['preferred_nickname'],
+    #                         "gender": user_profile['gender'],
+    #                         "age": user_profile['age'],
+    #                         "region": user_profile['region'],
+    #                         "health_conditions": user_profile['health_conditions'],
+    #                         "medications": user_profile['medications'],
+    #                         "interests": user_profile['interests'],
+    #                         "current_datetime": push_dt
+    #                     },
+    #                     dialogue_history=conversations,
+    #                     query_prompt_template=query_prompt_template
+    #                 )
+    #                 # TODO 获取打分
+    #                 score = '{"score":0.05}'
+    #                 # 更新子任务状态为已完成
+    #                 self.update_task_conversations_res(task_conversation.id,
+    #                                                    TestTaskConversationsStatus.SUCCESS.value,
+    #                                                    json.dumps(conversations, ensure_ascii=False),
+    #                                                    json.dumps(push_message, ensure_ascii=False),
+    #                                                    score)
+    #             except Exception as e:
+    #                 logger.error(f"Error executing task {task_id}: {str(e)}")
+    #                 self.update_task_conversations_status(task_conversation.id,
+    #                                                       TestTaskConversationsStatus.FAILED.value)
+    #
+    #         # 检查任务是否完成
+    #         task_conversations = self.get_task_conversations(task_id)
+    #         all_completed = all(task_conversation.status in
+    #                             (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
+    #                             for task_conversation in task_conversations)
+    #         any_pending = any(task_conversation.status in
+    #                           (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
+    #                           for task_conversation in task_conversations)
+    #
+    #         if all_completed:
+    #             self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+    #             logger.info(f"Task {task_id} completed")
+    #         elif not any_pending:
+    #             # 没有待处理子任务但未全部完成(可能是取消了)
+    #             current_status = self.get_task(task_id).status
+    #             if current_status != TestTaskStatus.CANCELLED.value:
+    #                 self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
+    #                 if all_completed else TestTaskStatus.CANCELLED.value)
+    #     except Exception as e:
+    #         logger.error(f"Error executing task {task_id}: {str(e)}")
+    #         self.update_task_status(task_id, TestTaskStatus.FAILED.value)
+    #     finally:
+    #         # 清理资源
+    #         with self.task_locks[task_id]:
+    #             if task_id in self.running_tasks:
+    #                 self.running_tasks.remove(task_id)
+    #             if task_id in self.task_events:
+    #                 del self.task_events[task_id]
+    #             if task_id in self.task_futures:
+    #                 del self.task_futures[task_id]
 
     def shutdown(self):
         """关闭执行器"""