Explorar o código

Update agent_service: resume push tasks when service restarts

StrayWarrior hai 1 mes
pai
achega
e7cf918ca7
Modificáronse 2 ficheiros con 25 adicións e 2 borrados
  1. 20 1
      pqai_agent/agent_service.py
  2. 5 1
      pqai_agent/push_service.py

+ 20 - 1
pqai_agent/agent_service.py

@@ -84,6 +84,8 @@ class AgentService:
         self.push_task_producer = None
         self.push_task_producer = None
         self.push_task_consumer = None
         self.push_task_consumer = None
         self._init_push_task_queue()
         self._init_push_task_queue()
+        self.next_push_disabled = True
+        self._resume_unfinished_push_task()
 
 
         self.send_rate_limiter = MessageSenderRateLimiter()
         self.send_rate_limiter = MessageSenderRateLimiter()
 
 
@@ -305,8 +307,24 @@ class AgentService:
         self.push_task_consumer.startup()
         self.push_task_consumer.startup()
         self.push_task_consumer.subscribe(rmq_topic)
         self.push_task_consumer.subscribe(rmq_topic)
 
 
+
+    def _resume_unfinished_push_task(self):
+        def run_unfinished_push_task():
+            logger.info("start to resume unfinished push task")
+            push_task_worker_pool = PushTaskWorkerPool(
+                self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+            push_task_worker_pool.start()
+            push_task_worker_pool.wait_to_finish()
+            self.next_push_disabled = False
+            logger.info("unfinished push tasks should be finished")
+        thread = threading.Thread(target=run_unfinished_push_task)
+        thread.start()
+
     def _check_initiative_conversations(self):
     def _check_initiative_conversations(self):
         logger.info("start to check initiative conversations")
         logger.info("start to check initiative conversations")
+        if self.next_push_disabled:
+            logger.info("previous push tasks in processing, next push is disabled")
+            return
         if not DialogueManager.is_time_suitable_for_active_conversation():
         if not DialogueManager.is_time_suitable_for_active_conversation():
             logger.info("time is not suitable for active conversation")
             logger.info("time is not suitable for active conversation")
             return
             return
@@ -324,7 +342,8 @@ class AgentService:
         push_task_worker_pool.start()
         push_task_worker_pool.start()
         for thread in push_scan_threads:
         for thread in push_scan_threads:
             thread.join()
             thread.join()
-        # 先等待生成任务全部提交,再等待任务处理线程池完成
+        # 由于扫描和生成异步,两次扫描之间可能消息并未处理完,会有重复生成任务的情况,因此需等待上一次任务结束
+        # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
         push_task_worker_pool.wait_to_finish()
         push_task_worker_pool.wait_to_finish()
 
 
     def _check_initiative_conversations_v1(self):
     def _check_initiative_conversations_v1(self):

+ 5 - 1
pqai_agent/push_service.py

@@ -1,6 +1,7 @@
 import json
 import json
 import time
 import time
 import uuid
 import uuid
+from datetime import datetime
 from enum import Enum
 from enum import Enum
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
 from threading import Thread
@@ -84,6 +85,8 @@ class PushTaskWorkerPool:
         self.loop_thread.start()
         self.loop_thread.start()
 
 
     def process_push_tasks(self):
     def process_push_tasks(self):
+        # RMQ consumer疑似有bug,创建后立即消费可能报NPE
+        time.sleep(1)
         while True:
         while True:
             msgs = self.consumer.receive(1, 300)
             msgs = self.consumer.receive(1, 300)
             if not msgs:
             if not msgs:
@@ -103,7 +106,8 @@ class PushTaskWorkerPool:
                     continue
                     continue
             msg = msgs[0]
             msg = msgs[0]
             task = json.loads(msg.body.decode('utf-8'))
             task = json.loads(msg.body.decode('utf-8'))
-            logger.debug(f"recv message: {task}")
+            msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
+            logger.debug(f"recv message:{msg_time} - {task}")
             if task['task_type'] == TaskType.GENERATE.value:
             if task['task_type'] == TaskType.GENERATE.value:
                 self.generate_executor.submit(self.handle_generate_task, task, msg)
                 self.generate_executor.submit(self.handle_generate_task, task, msg)
             elif task['task_type'] == TaskType.SEND.value:
             elif task['task_type'] == TaskType.SEND.value: