Selaa lähdekoodia

Update agent_service: support scheduler failover by delay queue

StrayWarrior 2 päivää sitten
vanhempi
commit
607b6a11aa
1 muutettua tiedostoa jossa 64 lisäystä ja 17 poistoa
  1. 64 17
      agent_service.py

+ 64 - 17
agent_service.py

@@ -72,7 +72,10 @@ class AgentService:
         self.chat_service_type = chat_service_type
 
         # 定时任务调度器
-        self.scheduler = BackgroundScheduler()
+        self.scheduler = None
+        self.scheduler_mode = self.config.get('system', {}).get('scheduler_mode', 'local')
+        self.scheduler_queue = None
+        self.msg_scheduler_thread = None
         self.limit_initiative_conversation_rate = True
         self.running = False
         self.process_thread = None
@@ -86,6 +89,43 @@ class AgentService:
             apscheduler.triggers.cron.CronTrigger(**schedule_params)
         )
 
+    def setup_scheduler(self):
+        """Create and start the job scheduler, plus the MQ consumer in 'mq' mode.
+
+        A ``BackgroundScheduler`` is always created and started. When
+        ``self.scheduler_mode`` is ``'mq'``, a DELAY-topic queue backend is
+        additionally built from ``self.config['mq']`` and a thread running
+        ``process_scheduler_events`` is started. That loop exits as soon as
+        ``self.running`` is false, so ``self.running`` must already be set
+        when this method is called.
+        """
+        self.scheduler = BackgroundScheduler()
+        if self.scheduler_mode == 'mq':
+            # Use the module-level logger, like the other methods of this class.
+            logger.info("setup event message scheduler with MQ")
+            mq_conf = self.config['mq']
+            topic = mq_conf['scheduler_topic']
+            self.scheduler_queue = AliyunRocketMQQueueBackend(
+                mq_conf['endpoints'],
+                mq_conf['instance_id'],
+                topic,
+                has_consumer=True, has_producer=True,
+                group_id=mq_conf['scheduler_group'],
+                topic_type='DELAY'
+            )
+            self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
+            self.msg_scheduler_thread.start()
+        self.scheduler.start()
+
+    def process_scheduler_events(self):
+        """Poll the scheduler queue and dispatch messages while ``self.running`` is true.
+
+        Each consumed message is passed to ``process_scheduler_event`` and is
+        acked only after that call returns. If consuming, processing or acking
+        raises, the error is logged with its traceback and the loop continues,
+        so the message is left un-acked. Used as the target of
+        ``msg_scheduler_thread``.
+        """
+        while self.running:
+            try:
+                # consume() is inside the try so a queue error cannot kill this thread.
+                msg = self.scheduler_queue.consume()
+                if msg:
+                    self.process_scheduler_event(msg)
+                    self.scheduler_queue.ack(msg)
+            except Exception as e:
+                logger.exception("Error processing scheduler event: {}".format(e))
+            # Fixed one-second pause between polls.
+            time.sleep(1)
+        logger.info("Scheduler event processing thread exit")
+
+    def process_scheduler_event(self, msg: Message):
+        """Dispatch a single message taken from the scheduler queue.
+
+        ``AGGREGATION_TRIGGER`` messages are forwarded to ``self.receive_queue``;
+        any other type is only logged as a warning.
+        """
+        if msg.type != MessageType.AGGREGATION_TRIGGER:
+            logger.warning(f"Unknown message type: {msg.type}")
+            return
+        # Delayed trigger: put it on the receive queue so the agent processes it.
+        self.receive_queue.produce(msg)
+
     def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
         """获取Agent实例"""
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
@@ -112,7 +152,11 @@ class AgentService:
         self.running = True
         self.process_thread = threading.Thread(target=service.process_messages)
         self.process_thread.start()
-        self.scheduler.start()
+        self.setup_scheduler()
+        # 只有企微场景需要主动发起
+        if not self.config['debug_flags'].get('disable_active_conversation', False):
+            schedule_param = self.config['agent_behavior'].get('active_conversation_schedule_param', None)
+            self.setup_initiative_conversations(schedule_param)
         signal.signal(signal.SIGINT, self._handle_sigint)
         if blocking:
             self.process_thread.join()
@@ -124,6 +168,11 @@ class AgentService:
         self.scheduler.shutdown()
         if sync:
             self.process_thread.join()
+            self.receive_queue.shutdown()
+            self.send_queue.shutdown()
+            if self.msg_scheduler_thread:
+                self.msg_scheduler_thread.join()
+                self.scheduler_queue.shutdown()
 
     def _handle_sigint(self, signum, frame):
         self._sigint_cnt += 1
@@ -151,12 +200,15 @@ class AgentService:
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
+        """Schedule an AGGREGATION_TRIGGER system message for ``delay_sec`` seconds from now.
+
+        The message timestamp is the target trigger time in epoch milliseconds,
+        and its ``msgId`` is the negated ``AGGREGATION_TRIGGER`` enum value.
+        In ``'mq'`` scheduler mode the message is produced to
+        ``self.scheduler_queue``; otherwise a one-shot ``'date'`` job on the
+        local scheduler produces it to ``self.receive_queue``.
+        """
         logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
-        message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
+        msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         # 系统消息使用特定的msgId,无实际意义
-        message.msgId = -MessageType.AGGREGATION_TRIGGER.value
-        self.scheduler.add_job(lambda: self.receive_queue.produce(message),
-                               'date',
-                               run_date=datetime.now() + timedelta(seconds=delay_sec))
+        msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
+        if self.scheduler_mode == 'mq':
+            # NOTE(review): delay_sec is not passed to produce(); presumably the
+            # DELAY-topic backend derives the delivery time from the message
+            # timestamp set above -- confirm against the queue backend.
+            self.scheduler_queue.produce(msg)
+        else:
+            # Local mode: one-shot job that fires once at now + delay_sec.
+            self.scheduler.add_job(lambda: self.receive_queue.produce(msg),
+                                   'date',
+                                   run_date=datetime.now() + timedelta(seconds=delay_sec))
 
     def process_single_message(self, message: Message):
         user_id = message.sender
@@ -358,13 +410,15 @@ if __name__ == "__main__":
             config['mq']['instance_id'],
             config['mq']['receive_topic'],
             has_consumer=True, has_producer=True,
-            group_id=config['mq']['receive_group']
+            group_id=config['mq']['receive_group'],
+            topic_type='FIFO'
         )
         send_queue = AliyunRocketMQQueueBackend(
             config['mq']['endpoints'],
             config['mq']['instance_id'],
             config['mq']['send_topic'],
-            has_consumer=False, has_producer=True
+            has_consumer=False, has_producer=True,
+            topic_type='FIFO'
         )
     else:
         receive_queue = MemoryQueueBackend()
@@ -390,8 +444,6 @@ if __name__ == "__main__":
             wecom_db_config['table']['user']
         )
 
-
-
     # 创建Agent服务
     service = AgentService(
         receive_backend=receive_queue,
@@ -401,11 +453,6 @@ if __name__ == "__main__":
         user_relation_manager=user_relation_manager,
         chat_service_type=ChatServiceType.COZE_CHAT
     )
-    # 只有企微场景需要主动发起
-    if not config['debug_flags'].get('disable_active_conversation', False):
-        schedule_param = config['agent_behavior'].get('active_conversation_schedule_param', None)
-        service.setup_initiative_conversations(schedule_param)
-
 
     if not config['debug_flags'].get('console_input', False):
         service.start(blocking=True)
@@ -414,7 +461,7 @@ if __name__ == "__main__":
         service.start()
 
     message_id = 0
-    while True:
+    while service.running:
         print("Input next message: ")
         text = sys.stdin.readline().strip()
         if not text: