소스 검색

Add push_service

StrayWarrior 1 개월 전
부모
커밋
d62e1c4a00
1개의 변경된 파일173개의 추가작업 그리고 0개의 파일을 삭제
  1. 173 0
      pqai_agent/push_service.py

+ 173 - 0
pqai_agent/push_service.py

@@ -0,0 +1,173 @@
+import json
+import time
+import uuid
+from enum import Enum
+from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
+from typing import Optional, Dict
+
+import rocketmq
+from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
+
+from pqai_agent import configs
+from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
+from pqai_agent.configs import apollo_config
+from pqai_agent.logging_service import logger
+from pqai_agent.message import MessageType
+
+
+class TaskType(Enum):
+    GENERATE = "generate"
+    SEND = "send"
+
+def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType, content: Optional[str] = None) -> rocketmq.Message:
+    msg = rocketmq.Message()
+    msg.topic = topic
+    msg.body = json.dumps({
+        'staff_id': staff_id,
+        'user_id': user_id,
+        'task_type': task_type.value,
+        # FIXME: 需要支持多模态消息
+        'content': content or '',
+        'timestamp': int(time.time() * 1000),
+    }, ensure_ascii=False).encode('utf-8')
+    msg.tag = task_type.value
+    return msg
+
+class PushScanThread:
+    # PushScanThread实际可以是AgentService的一个函数,从AgentService中独立的主要考虑因素为Push后续可能有拆分和扩展
+    def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str, mq_producer: rocketmq.Producer):
+        self.staff_id = staff_id
+        # 需要大量使用AgentService内部的成员
+        self.service = agent_service
+        self.rmq_topic = mq_topic
+        self.rmq_producer = mq_producer
+
+    def run(self):
+        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
+        first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
+        # 合并白名单,减少配置成本
+        white_list_tags.update(first_initiate_tags)
+        for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
+            staff_id = staff_user['staff_id']
+            user_id = staff_user['user_id']
+            agent = self.service.get_agent_instance(staff_id, user_id)
+            should_initiate = agent.should_initiate_conversation()
+            user_tags = self.service.user_relation_manager.get_user_tags(user_id)
+
+            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
+                should_initiate = False
+            if should_initiate:
+                logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
+                rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
+                self.rmq_producer.send(rmq_msg)
+            else:
+                logger.debug(f"user[{user_id}], do not initiate conversation")
+
+
+class PushTaskWorkerPool:
+    def __init__(self, agent_service: 'AgentService', mq_topic: str,
+                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
+        self.agent_service = agent_service
+        self.generate_executor = ThreadPoolExecutor(max_workers=5)
+        self.send_executors = {}
+        self.rmq_topic = mq_topic
+        self.consumer = mq_consumer
+        self.producer = mq_producer
+        self.loop_thread = None
+        self.is_generator_running = True
+        self.generate_send_done = False # set by wait_to_finish
+        self.no_more_generate_task = False # set by self
+
+    def start(self):
+        self.loop_thread = Thread(target=self.process_push_tasks)
+        self.loop_thread.start()
+
+    def process_push_tasks(self):
+        while True:
+            msgs = self.consumer.receive(1, 300)
+            if not msgs:
+                # 没有生成任务在执行且没有消息,才可退出
+                if self.generate_send_done:
+                    if not self.no_more_generate_task:
+                        logger.debug("no message received, there should be no more generate task")
+                        self.no_more_generate_task = True
+                        continue
+                    else:
+                        if self.is_generator_running:
+                            logger.debug("Waiting for generator threads to finish")
+                            continue
+                        else:
+                            break
+                else:
+                    continue
+            msg = msgs[0]
+            task = json.loads(msg.body.decode('utf-8'))
+            logger.debug(f"recv message: {task}")
+            if task['task_type'] == TaskType.GENERATE.value:
+                self.generate_executor.submit(self.handle_generate_task, task, msg)
+            elif task['task_type'] == TaskType.SEND.value:
+                staff_id = task['staff_id']
+                if staff_id not in self.send_executors:
+                    self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
+                self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
+            else:
+                logger.error(f"Unknown task type: {task['task_type']}")
+                self.consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool stopped")
+
+    def wait_to_finish(self):
+        self.generate_send_done = True
+        while not self.no_more_generate_task:
+            #FIXME(zhoutian): condition variable should be used to replace time sleep
+            time.sleep(1)
+        self.generate_executor.shutdown(wait=True)
+        self.is_generator_running = False
+        self.loop_thread.join()
+
+    def handle_send_task(self, task: Dict, msg: rocketmq.Message):
+        try:
+            staff_id = task['staff_id']
+            user_id = task['user_id']
+            agent = self.agent_service.get_agent_instance(staff_id, user_id)
+            # 二次校验是否需要发送
+            if not agent.should_initiate_conversation():
+                logger.debug(f"user[{user_id}], do not initiate conversation")
+                self.consumer.ack(msg)
+                return
+            content = task['content']
+            recent_dialogue = agent.dialogue_history[-10:]
+            agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist"))
+            if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
+                message_type = MessageType.TEXT
+            else:
+                message_type = self.agent_service.response_type_detector.detect_type(
+                    recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
+            response = agent.generate_response(content)
+            if response:
+                self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
+            else:
+                logger.debug(f"agent[{staff_id}] generate empty response")
+            self.consumer.ack(msg)
+        except Exception as e:
+            logger.error(f"Error processing message sending: {e}")
+            self.consumer.ack(msg)
+
+    def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
+        try:
+            staff_id = task['staff_id']
+            user_id = task['user_id']
+            main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
+            push_agent = DummyMessagePushAgent()
+            message_to_user = push_agent.generate_message(
+                context=main_agent.get_prompt_context(None),
+                dialogue_history=main_agent.dialogue_history
+            )
+            rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
+            logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
+            self.producer.send(rmq_message)
+            self.consumer.ack(msg)
+        except Exception as e:
+            logger.error(f"Error processing message generation: {e}")
+            # FIXME: 是否需要ACK
+            self.consumer.ack(msg)