|
@@ -0,0 +1,173 @@
|
|
|
+import json
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+from enum import Enum
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+from threading import Thread
|
|
|
+from typing import Optional, Dict
|
|
|
+
|
|
|
+import rocketmq
|
|
|
+from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
|
|
|
+
|
|
|
+from pqai_agent import configs
|
|
|
+from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
|
|
|
+from pqai_agent.configs import apollo_config
|
|
|
+from pqai_agent.logging_service import logger
|
|
|
+from pqai_agent.message import MessageType
|
|
|
+
|
|
|
+
|
|
|
+class TaskType(Enum):
|
|
|
+ GENERATE = "generate"
|
|
|
+ SEND = "send"
|
|
|
+
|
|
|
+def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType, content: Optional[str] = None) -> rocketmq.Message:
|
|
|
+ msg = rocketmq.Message()
|
|
|
+ msg.topic = topic
|
|
|
+ msg.body = json.dumps({
|
|
|
+ 'staff_id': staff_id,
|
|
|
+ 'user_id': user_id,
|
|
|
+ 'task_type': task_type.value,
|
|
|
+ # FIXME: 需要支持多模态消息
|
|
|
+ 'content': content or '',
|
|
|
+ 'timestamp': int(time.time() * 1000),
|
|
|
+ }, ensure_ascii=False).encode('utf-8')
|
|
|
+ msg.tag = task_type.value
|
|
|
+ return msg
|
|
|
+
|
|
|
+class PushScanThread:
|
|
|
+ # PushScanThread实际可以是AgentService的一个函数,从AgentService中独立的主要考虑因素为Push后续可能有拆分和扩展
|
|
|
+ def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str, mq_producer: rocketmq.Producer):
|
|
|
+ self.staff_id = staff_id
|
|
|
+ # 需要大量使用AgentService内部的成员
|
|
|
+ self.service = agent_service
|
|
|
+ self.rmq_topic = mq_topic
|
|
|
+ self.rmq_producer = mq_producer
|
|
|
+
|
|
|
+ def run(self):
|
|
|
+ white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
|
|
|
+ first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
|
|
|
+ # 合并白名单,减少配置成本
|
|
|
+ white_list_tags.update(first_initiate_tags)
|
|
|
+ for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
|
|
|
+ staff_id = staff_user['staff_id']
|
|
|
+ user_id = staff_user['user_id']
|
|
|
+ agent = self.service.get_agent_instance(staff_id, user_id)
|
|
|
+ should_initiate = agent.should_initiate_conversation()
|
|
|
+ user_tags = self.service.user_relation_manager.get_user_tags(user_id)
|
|
|
+
|
|
|
+ if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
|
|
|
+ should_initiate = False
|
|
|
+ if should_initiate:
|
|
|
+ logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
|
|
|
+ rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
|
|
|
+ self.rmq_producer.send(rmq_msg)
|
|
|
+ else:
|
|
|
+ logger.debug(f"user[{user_id}], do not initiate conversation")
|
|
|
+
|
|
|
+
|
|
|
+class PushTaskWorkerPool:
|
|
|
+ def __init__(self, agent_service: 'AgentService', mq_topic: str,
|
|
|
+ mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
|
|
|
+ self.agent_service = agent_service
|
|
|
+ self.generate_executor = ThreadPoolExecutor(max_workers=5)
|
|
|
+ self.send_executors = {}
|
|
|
+ self.rmq_topic = mq_topic
|
|
|
+ self.consumer = mq_consumer
|
|
|
+ self.producer = mq_producer
|
|
|
+ self.loop_thread = None
|
|
|
+ self.is_generator_running = True
|
|
|
+ self.generate_send_done = False # set by wait_to_finish
|
|
|
+ self.no_more_generate_task = False # set by self
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ self.loop_thread = Thread(target=self.process_push_tasks)
|
|
|
+ self.loop_thread.start()
|
|
|
+
|
|
|
+ def process_push_tasks(self):
|
|
|
+ while True:
|
|
|
+ msgs = self.consumer.receive(1, 300)
|
|
|
+ if not msgs:
|
|
|
+ # 没有生成任务在执行且没有消息,才可退出
|
|
|
+ if self.generate_send_done:
|
|
|
+ if not self.no_more_generate_task:
|
|
|
+ logger.debug("no message received, there should be no more generate task")
|
|
|
+ self.no_more_generate_task = True
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ if self.is_generator_running:
|
|
|
+ logger.debug("Waiting for generator threads to finish")
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ continue
|
|
|
+ msg = msgs[0]
|
|
|
+ task = json.loads(msg.body.decode('utf-8'))
|
|
|
+ logger.debug(f"recv message: {task}")
|
|
|
+ if task['task_type'] == TaskType.GENERATE.value:
|
|
|
+ self.generate_executor.submit(self.handle_generate_task, task, msg)
|
|
|
+ elif task['task_type'] == TaskType.SEND.value:
|
|
|
+ staff_id = task['staff_id']
|
|
|
+ if staff_id not in self.send_executors:
|
|
|
+ self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
|
|
|
+ self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
|
|
|
+ else:
|
|
|
+ logger.error(f"Unknown task type: {task['task_type']}")
|
|
|
+ self.consumer.ack(msg)
|
|
|
+ logger.info("PushGenerateWorkerPool stopped")
|
|
|
+
|
|
|
+ def wait_to_finish(self):
|
|
|
+ self.generate_send_done = True
|
|
|
+ while not self.no_more_generate_task:
|
|
|
+ #FIXME(zhoutian): condition variable should be used to replace time sleep
|
|
|
+ time.sleep(1)
|
|
|
+ self.generate_executor.shutdown(wait=True)
|
|
|
+ self.is_generator_running = False
|
|
|
+ self.loop_thread.join()
|
|
|
+
|
|
|
+ def handle_send_task(self, task: Dict, msg: rocketmq.Message):
|
|
|
+ try:
|
|
|
+ staff_id = task['staff_id']
|
|
|
+ user_id = task['user_id']
|
|
|
+ agent = self.agent_service.get_agent_instance(staff_id, user_id)
|
|
|
+ # 二次校验是否需要发送
|
|
|
+ if not agent.should_initiate_conversation():
|
|
|
+ logger.debug(f"user[{user_id}], do not initiate conversation")
|
|
|
+ self.consumer.ack(msg)
|
|
|
+ return
|
|
|
+ content = task['content']
|
|
|
+ recent_dialogue = agent.dialogue_history[-10:]
|
|
|
+ agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist"))
|
|
|
+ if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
|
|
|
+ message_type = MessageType.TEXT
|
|
|
+ else:
|
|
|
+ message_type = self.agent_service.response_type_detector.detect_type(
|
|
|
+ recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
|
|
|
+ response = agent.generate_response(content)
|
|
|
+ if response:
|
|
|
+ self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
|
|
|
+ else:
|
|
|
+ logger.debug(f"agent[{staff_id}] generate empty response")
|
|
|
+ self.consumer.ack(msg)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error processing message sending: {e}")
|
|
|
+ self.consumer.ack(msg)
|
|
|
+
|
|
|
+ def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
|
|
|
+ try:
|
|
|
+ staff_id = task['staff_id']
|
|
|
+ user_id = task['user_id']
|
|
|
+ main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
|
|
|
+ push_agent = DummyMessagePushAgent()
|
|
|
+ message_to_user = push_agent.generate_message(
|
|
|
+ context=main_agent.get_prompt_context(None),
|
|
|
+ dialogue_history=main_agent.dialogue_history
|
|
|
+ )
|
|
|
+ rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
|
|
|
+ logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
|
|
|
+ self.producer.send(rmq_message)
|
|
|
+ self.consumer.ack(msg)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error processing message generation: {e}")
|
|
|
+ # FIXME: 是否需要ACK
|
|
|
+ self.consumer.ack(msg)
|