|
@@ -72,7 +72,10 @@ class AgentService:
|
|
|
self.chat_service_type = chat_service_type
|
|
|
|
|
|
# 定时任务调度器
|
|
|
- self.scheduler = BackgroundScheduler()
|
|
|
+ self.scheduler = None
|
|
|
+ self.scheduler_mode = self.config.get('system', {}).get('scheduler_mode', 'local')
|
|
|
+ self.scheduler_queue = None
|
|
|
+ self.msg_scheduler_thread = None
|
|
|
self.limit_initiative_conversation_rate = True
|
|
|
self.running = False
|
|
|
self.process_thread = None
|
|
@@ -86,6 +89,43 @@ class AgentService:
|
|
|
apscheduler.triggers.cron.CronTrigger(**schedule_params)
|
|
|
)
|
|
|
|
|
|
+ def setup_scheduler(self):
|
|
|
+ self.scheduler = BackgroundScheduler()
|
|
|
+ if self.scheduler_mode == 'mq':
|
|
|
+ logging.info("setup event message scheduler with MQ")
|
|
|
+ mq_conf = self.config['mq']
|
|
|
+ topic = mq_conf['scheduler_topic']
|
|
|
+ self.scheduler_queue = AliyunRocketMQQueueBackend(
|
|
|
+ mq_conf['endpoints'],
|
|
|
+ mq_conf['instance_id'],
|
|
|
+ topic,
|
|
|
+ has_consumer=True, has_producer=True,
|
|
|
+ group_id=mq_conf['scheduler_group'],
|
|
|
+ topic_type='DELAY'
|
|
|
+ )
|
|
|
+ self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
|
|
|
+ self.msg_scheduler_thread.start()
|
|
|
+ self.scheduler.start()
|
|
|
+
|
|
|
+ def process_scheduler_events(self):
|
|
|
+ while self.running:
|
|
|
+ msg = self.scheduler_queue.consume()
|
|
|
+ if msg:
|
|
|
+ try:
|
|
|
+ self.process_scheduler_event(msg)
|
|
|
+ self.scheduler_queue.ack(msg)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error("Error processing scheduler event: {}".format(e))
|
|
|
+ time.sleep(1)
|
|
|
+ logger.info("Scheduler event processing thread exit")
|
|
|
+
|
|
|
+ def process_scheduler_event(self, msg: Message):
|
|
|
+ if msg.type == MessageType.AGGREGATION_TRIGGER:
|
|
|
+ # 延迟触发的消息,需放入接收队列以驱动Agent运转
|
|
|
+ self.receive_queue.produce(msg)
|
|
|
+ else:
|
|
|
+ logger.warning(f"Unknown message type: {msg.type}")
|
|
|
+
|
|
|
def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
|
|
|
"""获取Agent实例"""
|
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
@@ -112,7 +152,11 @@ class AgentService:
|
|
|
self.running = True
|
|
|
self.process_thread = threading.Thread(target=service.process_messages)
|
|
|
self.process_thread.start()
|
|
|
- self.scheduler.start()
|
|
|
+ self.setup_scheduler()
|
|
|
+ # 只有企微场景需要主动发起
|
|
|
+ if not self.config['debug_flags'].get('disable_active_conversation', False):
|
|
|
+ schedule_param = self.config['agent_behavior'].get('active_conversation_schedule_param', None)
|
|
|
+ self.setup_initiative_conversations(schedule_param)
|
|
|
signal.signal(signal.SIGINT, self._handle_sigint)
|
|
|
if blocking:
|
|
|
self.process_thread.join()
|
|
@@ -124,6 +168,11 @@ class AgentService:
|
|
|
self.scheduler.shutdown()
|
|
|
if sync:
|
|
|
self.process_thread.join()
|
|
|
+ self.receive_queue.shutdown()
|
|
|
+ self.send_queue.shutdown()
|
|
|
+ if self.msg_scheduler_thread:
|
|
|
+ self.msg_scheduler_thread.join()
|
|
|
+ self.scheduler_queue.shutdown()
|
|
|
|
|
|
def _handle_sigint(self, signum, frame):
|
|
|
self._sigint_cnt += 1
|
|
@@ -151,12 +200,15 @@ class AgentService:
|
|
|
def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
|
|
|
logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
|
|
|
message_ts = int((time.time() + delay_sec) * 1000)
|
|
|
- message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
|
|
|
+ msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
|
|
|
# 系统消息使用特定的msgId,无实际意义
|
|
|
- message.msgId = -MessageType.AGGREGATION_TRIGGER.value
|
|
|
- self.scheduler.add_job(lambda: self.receive_queue.produce(message),
|
|
|
- 'date',
|
|
|
- run_date=datetime.now() + timedelta(seconds=delay_sec))
|
|
|
+ msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
|
|
|
+ if self.scheduler_mode == 'mq':
|
|
|
+ self.scheduler_queue.produce(msg)
|
|
|
+ else:
|
|
|
+ self.scheduler.add_job(lambda: self.receive_queue.produce(msg),
|
|
|
+ 'date',
|
|
|
+ run_date=datetime.now() + timedelta(seconds=delay_sec))
|
|
|
|
|
|
def process_single_message(self, message: Message):
|
|
|
user_id = message.sender
|
|
@@ -358,13 +410,15 @@ if __name__ == "__main__":
|
|
|
config['mq']['instance_id'],
|
|
|
config['mq']['receive_topic'],
|
|
|
has_consumer=True, has_producer=True,
|
|
|
- group_id=config['mq']['receive_group']
|
|
|
+ group_id=config['mq']['receive_group'],
|
|
|
+ topic_type='FIFO'
|
|
|
)
|
|
|
send_queue = AliyunRocketMQQueueBackend(
|
|
|
config['mq']['endpoints'],
|
|
|
config['mq']['instance_id'],
|
|
|
config['mq']['send_topic'],
|
|
|
- has_consumer=False, has_producer=True
|
|
|
+ has_consumer=False, has_producer=True,
|
|
|
+ topic_type='FIFO'
|
|
|
)
|
|
|
else:
|
|
|
receive_queue = MemoryQueueBackend()
|
|
@@ -390,8 +444,6 @@ if __name__ == "__main__":
|
|
|
wecom_db_config['table']['user']
|
|
|
)
|
|
|
|
|
|
-
|
|
|
-
|
|
|
# 创建Agent服务
|
|
|
service = AgentService(
|
|
|
receive_backend=receive_queue,
|
|
@@ -401,11 +453,6 @@ if __name__ == "__main__":
|
|
|
user_relation_manager=user_relation_manager,
|
|
|
chat_service_type=ChatServiceType.COZE_CHAT
|
|
|
)
|
|
|
- # 只有企微场景需要主动发起
|
|
|
- if not config['debug_flags'].get('disable_active_conversation', False):
|
|
|
- schedule_param = config['agent_behavior'].get('active_conversation_schedule_param', None)
|
|
|
- service.setup_initiative_conversations(schedule_param)
|
|
|
-
|
|
|
|
|
|
if not config['debug_flags'].get('console_input', False):
|
|
|
service.start(blocking=True)
|
|
@@ -414,7 +461,7 @@ if __name__ == "__main__":
|
|
|
service.start()
|
|
|
|
|
|
message_id = 0
|
|
|
- while True:
|
|
|
+ while service.running:
|
|
|
print("Input next message: ")
|
|
|
text = sys.stdin.readline().strip()
|
|
|
if not text:
|