Explorar o código

Update push_service: split send and generate consumer

StrayWarrior hai 2 días
pai
achega
6458107cc7
Modificáronse 2 ficheiros con 46 adicións e 20 borrados
  1. 22 8
      pqai_agent/agent_service.py
  2. 24 12
      pqai_agent/push_service.py

+ 22 - 8
pqai_agent/agent_service.py

@@ -15,9 +15,10 @@ import traceback
 import apscheduler.triggers.cron
 import rocketmq
 from apscheduler.schedulers.background import BackgroundScheduler
+from rocketmq import FilterExpression
 from sqlalchemy.orm import sessionmaker
 
-from pqai_agent import configs
+from pqai_agent import configs, push_service
 from pqai_agent.abtest.utils import get_abtest_info
 from pqai_agent.agent_config_manager import AgentConfigManager
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
@@ -96,7 +97,8 @@ class AgentService:
 
         # Push相关
         self.push_task_producer = None
-        self.push_task_consumer = None
+        self.push_generate_task_consumer = None
+        self.push_send_task_consumer = None
         self._init_push_task_queue()
         self.next_push_disabled = True
         self._resume_unfinished_push_task()
@@ -384,20 +386,31 @@ class AgentService:
         mq_conf = configs.get()['mq']
         rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
         rmq_topic = mq_conf['push_tasks_topic']
-        rmq_group = mq_conf['push_tasks_group']
+        rmq_group_generate = mq_conf['push_generate_task_group']
+        rmq_group_send = mq_conf['push_send_task_group']
         self.push_task_rmq_topic = rmq_topic
         self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
         self.push_task_producer.startup()
-        self.push_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group, await_duration=5)
-        self.push_task_consumer.startup()
-        self.push_task_consumer.subscribe(rmq_topic)
+        # FIXME: 不应该暴露到agent service中
+        self.push_generate_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_generate, await_duration=5)
+        self.push_generate_task_consumer.startup()
+        self.push_generate_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.GENERATE.value)
+        )
+        self.push_send_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group_send, await_duration=5)
+        self.push_send_task_consumer.startup()
+        self.push_send_task_consumer.subscribe(
+            rmq_topic, filter_expression=FilterExpression(push_service.TaskType.SEND.value)
+        )
 
 
     def _resume_unfinished_push_task(self):
         def run_unfinished_push_task():
             logger.info("start to resume unfinished push task")
             push_task_worker_pool = PushTaskWorkerPool(
-                self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+                self, self.push_task_rmq_topic, self.push_generate_task_consumer,
+                self.push_send_task_consumer, self.push_task_producer
+            )
             push_task_worker_pool.start()
             push_task_worker_pool.wait_to_finish()
             self.next_push_disabled = False
@@ -427,7 +440,8 @@ class AgentService:
             push_scan_threads.append(scan_thread)
 
         push_task_worker_pool = PushTaskWorkerPool(
-            self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+            self, self.push_task_rmq_topic,
+            self.push_generate_task_consumer, self.push_send_task_consumer, self.push_task_producer)
         push_task_worker_pool.start()
         for thread in push_scan_threads:
             thread.join()

+ 24 - 12
pqai_agent/push_service.py

@@ -63,7 +63,7 @@ class PushScanThread:
             #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
             #     continue
             user_tags = self.service.user_relation_manager.get_user_tags(user_id)
-            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
+            if not white_list_tags.intersection(user_tags):
                 should_initiate = False
             else:
                 agent = self.service.get_agent_instance(staff_id, user_id)
@@ -78,14 +78,17 @@ class PushScanThread:
 
 class PushTaskWorkerPool:
     def __init__(self, agent_service: 'AgentService', mq_topic: str,
-                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
+                 mq_consumer_generate: rocketmq.SimpleConsumer,
+                 mq_consumer_send: rocketmq.SimpleConsumer,
+                 mq_producer: rocketmq.Producer):
         self.agent_service = agent_service
         max_workers = configs.get()['system'].get('push_task_workers', 5)
         self.max_push_workers = max_workers
         self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
         self.send_executors = {}
         self.rmq_topic = mq_topic
-        self.consumer = mq_consumer
+        self.generate_consumer = mq_consumer_generate
+        self.send_consumer = mq_consumer_send
         self.producer = mq_producer
         self.loop_thread = None
         self.is_generator_running = True
@@ -100,7 +103,13 @@ class PushTaskWorkerPool:
         # RMQ consumer疑似有bug,创建后立即消费可能报NPE
         time.sleep(1)
         while True:
-            msgs = self.consumer.receive(1, 300)
+            # FIXME: 拆分为两个单独的线程
+            # 目前优先处理发送任务
+            task_source = 'send'
+            msgs = self.send_consumer.receive(1, 60)
+            if not msgs:
+                task_source = 'generate'
+                msgs = self.generate_consumer.receive(1, 300)
             if not msgs:
                 # 没有生成任务在执行且没有消息,才可退出
                 if self.generate_send_done:
@@ -136,7 +145,10 @@ class PushTaskWorkerPool:
                 self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
             else:
                 logger.error(f"Unknown task type: {task['task_type']}")
-                self.consumer.ack(msg)
+                if task_source == 'send':
+                    self.send_consumer.ack(msg)
+                else:
+                    self.generate_consumer.ack(msg)
         logger.info("PushGenerateWorkerPool stopped")
 
     def wait_to_finish(self):
@@ -155,13 +167,13 @@ class PushTaskWorkerPool:
             agent = self.agent_service.get_agent_instance(staff_id, user_id)
             # 二次校验是否需要发送
             if not agent.should_initiate_conversation():
-                logger.debug(f"user[{user_id}], do not initiate conversation")
-                self.consumer.ack(msg)
+                logger.debug(f"user[{user_id}], should not initiate, skip sending task")
+                self.send_consumer.ack(msg)
                 return
             contents: List[Dict] = json.loads(task['content'])
             if not contents:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
-                self.consumer.ack(msg)
+                self.send_consumer.ack(msg)
                 return
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
@@ -197,11 +209,11 @@ class PushTaskWorkerPool:
                 agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message sending: {e}, {fmt_exc}")
-            self.consumer.ack(msg)
+            self.send_consumer.ack(msg)
 
     def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
         try:
@@ -237,9 +249,9 @@ class PushTaskWorkerPool:
                 self.producer.send(rmq_message)
             else:
                 logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)
         except Exception as e:
             fmt_exc = traceback.format_exc()
             logger.error(f"Error processing message generation: {e}, {fmt_exc}")
             # FIXME: 是否需要ACK
-            self.consumer.ack(msg)
+            self.generate_consumer.ack(msg)