import json
import time
import traceback
import uuid
from datetime import datetime
from enum import Enum
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Optional, Dict, List

import rocketmq
from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression

from pqai_agent import configs
from pqai_agent.abtest.utils import get_abtest_info
from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
from pqai_agent.configs import apollo_config
from pqai_agent.data_models.agent_push_record import AgentPushRecord
from pqai_agent.logging_service import logger
from pqai_agent.mq_message import MessageType
from pqai_agent.toolkit import get_tools
from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config


class TaskType(Enum):
    """Kinds of push tasks carried over RocketMQ; the value is also used as the message tag."""
    GENERATE = "generate"
    SEND = "send"


def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType,
                              content: Optional[str] = None) -> rocketmq.Message:
    """Build a RocketMQ message describing a push task.

    The body is UTF-8 encoded JSON with the keys ``staff_id``, ``user_id``,
    ``task_type``, ``content`` and ``timestamp`` (milliseconds since the epoch).
    The message tag is ``task_type.value``.

    Args:
        topic: Destination topic.
        staff_id: Staff account the task belongs to.
        user_id: Target user.
        task_type: GENERATE or SEND.
        content: Optional payload. SEND tasks pass a JSON string here so that
            multimodal messages can be carried. ``None`` is encoded as ``''``.

    Returns:
        The populated ``rocketmq.Message``.
    """
    msg = rocketmq.Message()
    msg.topic = topic
    msg.body = json.dumps({
        'staff_id': staff_id,
        'user_id': user_id,
        'task_type': task_type.value,
        # NOTE: multimodal messages are supported by passing JSON as the content
        'content': content or '',
        'timestamp': int(time.time() * 1000),
    }, ensure_ascii=False).encode('utf-8')
    msg.tag = task_type.value
    return msg


class PushScanThread:
    """Scan one staff's users and enqueue a GENERATE task for each user who should get a push.

    This could really be a method of AgentService. It is kept separate mainly
    because push may later be split out and extended.
    """

    def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str,
                 mq_producer: rocketmq.Producer):
        self.staff_id = staff_id
        # Uses many AgentService internals (user relations, agent instances).
        self.service = agent_service
        self.rmq_topic = mq_topic
        self.rmq_producer = mq_producer

    def run(self) -> None:
        """Enqueue a GENERATE task for every user of ``self.staff_id`` that should be contacted.

        Outside the ``dev`` environment, only users carrying at least one tag
        from the merged Apollo whitelists (``agent_initiate_whitelist_tags`` and
        ``agent_first_initiate_whitelist_tags``) are considered. For those users
        the agent's ``should_initiate_conversation()`` makes the final decision.
        """
        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
        first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
        # Merge the whitelists so only one set must be checked (reduces configuration cost).
        white_list_tags.update(first_initiate_tags)
        # The environment does not change during one scan; read it once.
        enforce_whitelist = configs.get_env() != 'dev'
        for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
            staff_id = staff_user['staff_id']
            user_id = staff_user['user_id']
            # Disabled: control per user group via A/B-test config whether push is enabled.
            # abtest_params = get_abtest_info(user_id).params
            # if abtest_params.get('agent_push_enabled', 'false').lower() != 'true':
            #     logger.debug(f"User {user_id} not enabled agent push, skipping.")
            #     continue
            user_tags = self.service.user_relation_manager.get_user_tags(user_id)

            if enforce_whitelist and not white_list_tags.intersection(user_tags):
                should_initiate = False
            else:
                agent = self.service.get_agent_instance(staff_id, user_id)
                should_initiate = agent.should_initiate_conversation()

            if should_initiate:
                logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
                rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
                self.rmq_producer.send(rmq_msg)
            else:
                logger.debug(f"user[{user_id}], do not initiate conversation")


class PushTaskWorkerPool:
    """Consume push tasks from RocketMQ and run them on thread pools.

    GENERATE tasks run on a shared pool of ``system.push_task_workers``
    threads (default 5). SEND tasks run on a single-worker pool per staff, so
    one staff's sends are serialized. Call ``start()`` to begin consuming and
    ``wait_to_finish()`` to drain and stop.
    """

    def __init__(self, agent_service: 'AgentService', mq_topic: str,
                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
        self.agent_service = agent_service
        max_workers = configs.get()['system'].get('push_task_workers', 5)
        self.max_push_workers = max_workers
        self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
        # staff_id -> single-worker executor; only mutated by the consume-loop thread.
        self.send_executors: Dict[str, ThreadPoolExecutor] = {}
        self.rmq_topic = mq_topic
        self.consumer = mq_consumer
        self.producer = mq_producer
        self.loop_thread: Optional[Thread] = None
        self.is_generator_running = True  # cleared by wait_to_finish once the generate pool drained
        self.generate_send_done = False  # set by wait_to_finish
        self.no_more_generate_task = False  # set by the consume loop

    def start(self) -> None:
        """Start the background consume loop (``process_push_tasks``)."""
        self.loop_thread = Thread(target=self.process_push_tasks)
        self.loop_thread.start()

    def process_push_tasks(self) -> None:
        """Consume and dispatch push tasks until shutdown completes.

        The loop exits only after ``wait_to_finish`` has been called, an empty
        poll has been observed, and the generate pool has been drained.
        Malformed messages are logged, acked and skipped so that they cannot
        kill this thread.
        """
        # The RMQ consumer appears to have a bug: consuming right after creation may throw an NPE.
        time.sleep(1)
        while True:
            # NOTE(review): args presumably are (max_message_num=1, invisible_duration=300s) — confirm.
            msgs = self.consumer.receive(1, 300)
            if not msgs:
                # Exit only when no generate task is running and no message is pending.
                if not self.generate_send_done:
                    continue
                if not self.no_more_generate_task:
                    logger.debug("no message received, there should be no more generate task")
                    self.no_more_generate_task = True
                    continue
                if self.is_generator_running:
                    logger.debug("Waiting for generator threads to finish")
                    continue
                break

            msg = msgs[0]
            try:
                task = json.loads(msg.body.decode('utf-8'))
                task_type = task['task_type']
                msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
            except (ValueError, KeyError, TypeError, OverflowError, OSError) as e:
                # An unparsable message would otherwise kill this thread and leave
                # wait_to_finish() waiting forever; drop it instead.
                logger.error(f"Malformed push task message, dropping: {e}")
                self.consumer.ack(msg)
                continue
            logger.debug(f"recv message:{msg_time} - {task}")

            if task_type == TaskType.GENERATE.value:
                # FIXME: temporary workaround to avoid consumed messages timing out while
                # queued and then being consumed again. Relies on a private executor attribute.
                if self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
                    logger.warning("Too many generate tasks in queue, consume this task later")
                    while self.generate_executor._work_queue.qsize() > self.max_push_workers * 5:
                        time.sleep(10)
                    # Do not submit or ack this message; presumably the broker redelivers it
                    # after its invisible duration expires.
                    continue
                self.generate_executor.submit(self.handle_generate_task, task, msg)
            elif task_type == TaskType.SEND.value:
                # A missing staff_id is reported (and acked) by handle_send_task's error path.
                staff_id = task.get('staff_id')
                if staff_id not in self.send_executors:
                    self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
                self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
            else:
                logger.error(f"Unknown task type: {task['task_type']}")
                self.consumer.ack(msg)
        logger.info("PushGenerateWorkerPool stopped")

    def wait_to_finish(self) -> None:
        """Signal shutdown and block until all generate and send work has completed."""
        self.generate_send_done = True
        if self.loop_thread is None:
            # start() was never called: nothing consumes messages, so just release the pool.
            self.generate_executor.shutdown(wait=True)
            self.is_generator_running = False
            return
        # Stop waiting if the loop thread died, otherwise the flag would never be set.
        while not self.no_more_generate_task and self.loop_thread.is_alive():
            # FIXME(zhoutian): condition variable should be used to replace time sleep
            time.sleep(1)
        self.generate_executor.shutdown(wait=True)
        self.is_generator_running = False
        self.loop_thread.join()
        # The loop has exited, so send_executors is no longer mutated; wait for pending sends.
        for executor in self.send_executors.values():
            executor.shutdown(wait=True)

    def handle_send_task(self, task: Dict, msg: rocketmq.Message) -> None:
        """Turn a SEND task's generated contents into replies, record them, and deliver them.

        ``task['content']`` is a JSON list of items with a ``type`` key and
        (for text items) a ``content`` key. The message is acked on every path,
        including failures.
        """
        try:
            staff_id = task['staff_id']
            user_id = task['user_id']
            agent = self.agent_service.get_agent_instance(staff_id, user_id)
            # Re-check whether the message should still be sent.
            if not agent.should_initiate_conversation():
                logger.debug(f"user[{user_id}], do not initiate conversation")
                self.consumer.ack(msg)
                return
            # Missing/empty content means "nothing to send", not a JSON error.
            contents: List[Dict] = json.loads(task.get('content') or '[]')
            if not contents:
                logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
                self.consumer.ack(msg)
                return

            recent_dialogue = agent.dialogue_history[-10:]
            agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
            messages_to_send = []
            current_ts = int(time.time())
            for item in contents:
                item["timestamp"] = current_ts * 1000
                if item["type"] == "text":
                    # Only whitelisted staff may have text converted to another type (e.g. voice).
                    if staff_id not in agent_voice_whitelist:
                        message_type = MessageType.TEXT
                    else:
                        message_type = self.agent_service.response_type_detector.detect_type(
                            recent_dialogue, item, enable_random=True)
                    response = agent.generate_response(item["content"])
                    if response:
                        messages_to_send.append({'type': message_type, 'content': response})
                else:
                    message_type = MessageType.from_str(item["type"])
                    response = agent.generate_multimodal_response(item)
                    if response:
                        item["type"] = message_type
                        messages_to_send.append(item)

            with self.agent_service.agent_db_session_maker() as session:
                msg_list = [{"type": outgoing["type"].value, "content": outgoing["content"]}
                            for outgoing in messages_to_send]
                record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                         content=json.dumps(msg_list, ensure_ascii=False),
                                         timestamp=current_ts)
                session.add(record)
                session.commit()

            if messages_to_send:
                for response in messages_to_send:
                    self.agent_service.send_multimodal_response(staff_id, user_id, response, skip_check=True)
                agent.update_last_active_interaction_time(current_ts)
            else:
                logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
            self.consumer.ack(msg)
        except Exception as e:
            fmt_exc = traceback.format_exc()
            logger.error(f"Error processing message sending: {e}, {fmt_exc}")
            self.consumer.ack(msg)

    def handle_generate_task(self, task: Dict, msg: rocketmq.Message) -> None:
        """Generate a push message for the task's user and enqueue it as a SEND task.

        The push agent is configured from the user's A/B-test agent config when
        one exists, otherwise a default ``MessagePushAgent`` is used. The
        message is acked on every path, including failures.
        """
        try:
            staff_id = task['staff_id']
            user_id = task['user_id']
            main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
            agent_config = get_agent_abtest_config('push', user_id,
                                                   self.agent_service.service_module_manager,
                                                   self.agent_service.agent_config_manager)
            if agent_config:
                try:
                    tool_names = json.loads(agent_config.tools)
                except (json.JSONDecodeError, TypeError):
                    # TypeError covers a missing (None) tools field.
                    logger.error(f"Invalid JSON in agent tools: {agent_config.tools}")
                    tool_names = []
                push_agent = MessagePushAgent(model=agent_config.execution_model,
                                              system_prompt=agent_config.system_prompt,
                                              tools=get_tools(tool_names))
                query_prompt_template = agent_config.task_prompt
            else:
                push_agent = MessagePushAgent()
                query_prompt_template = None

            message_to_user = push_agent.generate_message(
                context=main_agent.get_prompt_context(None),
                dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=30
                ),
                query_prompt_template=query_prompt_template
            )
            if message_to_user:
                rmq_message = generate_task_rmq_message(
                    self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                self.producer.send(rmq_message)
            else:
                logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")
            self.consumer.ack(msg)
        except Exception as e:
            fmt_exc = traceback.format_exc()
            logger.error(f"Error processing message generation: {e}, {fmt_exc}")
            # FIXME: should a failed generation really be ACKed?
            self.consumer.ack(msg)