123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- import json
- import time
- import traceback
- import uuid
- from datetime import datetime
- from enum import Enum
- from concurrent.futures import ThreadPoolExecutor
- from threading import Thread
- from typing import Optional, Dict
- import rocketmq
- from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
- from pqai_agent import configs
- from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
- from pqai_agent.configs import apollo_config
- from pqai_agent.logging_service import logger
- from pqai_agent.message import MessageType
class TaskType(Enum):
    """Kinds of push task carried by a queue message.

    The member value is what ``generate_task_rmq_message`` writes into the
    ``task_type`` field of the message body and uses as the message tag.
    """

    # Request to generate the content of a proactive message.
    GENERATE = "generate"
    # Request to deliver previously generated content.
    SEND = "send"
def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType,
                              content: Optional[str] = None) -> rocketmq.Message:
    """Build the RocketMQ message that describes one push task.

    Args:
        topic: Topic the message is addressed to.
        staff_id: Identifier stored in the payload's ``staff_id`` field.
        user_id: Identifier stored in the payload's ``user_id`` field.
        task_type: Task kind; its value is stored in the payload and also
            used as the message tag.
        content: Text payload of the task. ``None`` (or any other falsy
            value) is stored as an empty string.

    Returns:
        A message whose body is the UTF-8 encoded JSON payload, stamped with
        the current time in milliseconds.
    """
    message = rocketmq.Message()
    message.topic = topic
    payload = {
        'staff_id': staff_id,
        'user_id': user_id,
        'task_type': task_type.value,
        # FIXME: multimodal messages need to be supported
        'content': content if content else '',
        'timestamp': int(time.time() * 1000),
    }
    serialized = json.dumps(payload, ensure_ascii=False)
    message.body = serialized.encode('utf-8')
    message.tag = task_type.value
    return message
class PushScanThread:
    """Scan the users of one staff account and enqueue GENERATE tasks.

    This could simply be a function of AgentService; it is kept separate
    mainly because push handling may be split out and extended later.
    """

    def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str,
                 mq_producer: rocketmq.Producer):
        """Store the staff id, the owning service and the queue handles."""
        self.staff_id = staff_id
        # Members of AgentService are used extensively by run()
        self.service = agent_service
        self.rmq_topic = mq_topic
        self.rmq_producer = mq_producer

    def run(self):
        """Publish a GENERATE task for every user who should be contacted.

        A user qualifies when the agent wants to initiate a conversation and,
        outside the ``dev`` environment, the user carries at least one
        whitelisted tag.
        """
        allowed_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
        # Merge both whitelists to reduce configuration effort
        allowed_tags |= set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))

        for relation in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
            staff_id, user_id = relation['staff_id'], relation['user_id']
            agent = self.service.get_agent_instance(staff_id, user_id)
            wants_to_initiate = agent.should_initiate_conversation()
            user_tags = self.service.user_relation_manager.get_user_tags(user_id)
            outside_whitelist = configs.get_env() != 'dev' and not allowed_tags.intersection(user_tags)

            if outside_whitelist or not wants_to_initiate:
                logger.debug(f"user[{user_id}], do not initiate conversation")
                continue

            logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
            task_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
            self.rmq_producer.send(task_message)
class PushTaskWorkerPool:
    """Consume push tasks from RocketMQ and run them on worker threads.

    GENERATE tasks share one pool of five workers. SEND tasks get one
    single-worker executor per staff id, so the sends of one staff account
    run one at a time.

    Shutdown is coordinated through three flags:

    * ``generate_send_done`` - set by ``wait_to_finish`` once the caller has
      finished publishing GENERATE tasks.
    * ``no_more_generate_task`` - set by the consumer loop on the first empty
      poll after ``generate_send_done``.
    * ``is_generator_running`` - cleared by ``wait_to_finish`` after the
      generate executor has drained; the loop exits on the next empty poll.
    """

    def __init__(self, agent_service: 'AgentService', mq_topic: str,
                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
        """Store collaborators and create the generate executor."""
        self.agent_service = agent_service
        self.generate_executor = ThreadPoolExecutor(max_workers=5)
        # staff_id -> single-worker executor, created lazily on first SEND task
        self.send_executors = {}
        self.rmq_topic = mq_topic
        self.consumer = mq_consumer
        self.producer = mq_producer
        self.loop_thread = None
        self.is_generator_running = True
        self.generate_send_done = False  # set by wait_to_finish
        self.no_more_generate_task = False  # set by self

    def start(self):
        """Start the consumer loop on a background thread."""
        self.loop_thread = Thread(target=self.process_push_tasks)
        self.loop_thread.start()

    def process_push_tasks(self):
        """Poll the consumer and dispatch tasks until shutdown completes.

        A message that cannot be parsed or dispatched is logged and
        acknowledged instead of terminating the loop; a dead loop would leave
        ``wait_to_finish`` waiting for a flag that is never set.
        """
        # The RMQ consumer seems to have a bug: consuming immediately after
        # creation may raise an NPE
        time.sleep(1)
        while True:
            msgs = self.consumer.receive(1, 300)
            if not msgs:
                if self._should_exit_when_idle():
                    break
                continue
            msg = msgs[0]
            try:
                self._dispatch_message(msg)
            except Exception as e:
                fmt_exc = traceback.format_exc()
                logger.error(f"Error dispatching push task: {e}, {fmt_exc}")
                # Same policy as the task handlers: acknowledge failed messages
                self.consumer.ack(msg)
        logger.info("PushGenerateWorkerPool stopped")

    def _should_exit_when_idle(self) -> bool:
        """Advance the shutdown state after an empty poll; return True to stop."""
        # Exit only when no generate task is running and there are no messages
        if not self.generate_send_done:
            return False
        if not self.no_more_generate_task:
            logger.debug("no message received, there should be no more generate task")
            self.no_more_generate_task = True
            return False
        if self.is_generator_running:
            logger.debug("Waiting for generator threads to finish")
            return False
        return True

    def _dispatch_message(self, msg):
        """Decode one message and hand it to the executor for its task type."""
        task = json.loads(msg.body.decode('utf-8'))
        msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
        logger.debug(f"recv message:{msg_time} - {task}")
        task_type = task['task_type']
        if task_type == TaskType.GENERATE.value:
            self.generate_executor.submit(self.handle_generate_task, task, msg)
        elif task_type == TaskType.SEND.value:
            staff_id = task['staff_id']
            executor = self.send_executors.get(staff_id)
            if executor is None:
                executor = ThreadPoolExecutor(max_workers=1)
                self.send_executors[staff_id] = executor
            executor.submit(self.handle_send_task, task, msg)
        else:
            logger.error(f"Unknown task type: {task_type}")
            # No handler will see this message, so acknowledge it here
            self.consumer.ack(msg)

    def wait_to_finish(self):
        """Block until all queued generate and send work has completed.

        Call after the last GENERATE task has been published. Returns without
        waiting on the consumer loop if ``start`` was never called or the
        loop thread is no longer alive.
        """
        self.generate_send_done = True
        loop_thread = self.loop_thread
        while (loop_thread is not None and loop_thread.is_alive()
               and not self.no_more_generate_task):
            # FIXME(zhoutian): condition variable should be used to replace time sleep
            time.sleep(1)
        self.generate_executor.shutdown(wait=True)
        self.is_generator_running = False
        if loop_thread is not None:
            loop_thread.join()
        # The loop has stopped, so no new send executors can appear; drain and
        # release the per-staff executors so pending sends finish and their
        # threads are not leaked.
        for send_executor in list(self.send_executors.values()):
            send_executor.shutdown(wait=True)

    def handle_send_task(self, task: Dict, msg: rocketmq.Message):
        """Deliver the content of one SEND task, then acknowledge the message.

        The message is acknowledged on every path, including failures, which
        are logged.

        Args:
            task: Decoded task payload; ``staff_id``, ``user_id`` and
                ``content`` are read.
            msg: The queue message the task came from.
        """
        try:
            staff_id = task['staff_id']
            user_id = task['user_id']
            agent = self.agent_service.get_agent_instance(staff_id, user_id)
            # Re-check whether the message still needs to be sent
            if not agent.should_initiate_conversation():
                logger.debug(f"user[{user_id}], do not initiate conversation")
                self.consumer.ack(msg)
                return
            content = task['content']
            recent_dialogue = agent.dialogue_history[-10:]
            # Default to an empty whitelist, like the other whitelist lookups
            # in this file, so a missing config entry does not fail the task.
            agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
            # FIXME(zhoutian): this should no longer be controlled by the agent,
            # or the agent and the API should share the same configuration
            if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
                message_type = MessageType.TEXT
            else:
                message_type = self.agent_service.response_type_detector.detect_type(
                    recent_dialogue, content, enable_random=True)
            response = agent.generate_response(content)
            if response:
                self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
            else:
                logger.debug(f"agent[{staff_id}] generate empty response")
            self.consumer.ack(msg)
        except Exception as e:
            fmt_exc = traceback.format_exc()
            logger.error(f"Error processing message sending: {e}, {fmt_exc}")
            self.consumer.ack(msg)

    def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
        """Generate push content for one GENERATE task and enqueue a SEND task.

        The message is acknowledged on every path, including failures, which
        are logged.

        Args:
            task: Decoded task payload; ``staff_id`` and ``user_id`` are read.
            msg: The queue message the task came from.
        """
        try:
            staff_id = task['staff_id']
            user_id = task['user_id']
            main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
            push_agent = MessagePushAgent()
            message_to_user = push_agent.generate_message(
                context=main_agent.get_prompt_context(None),
                dialogue_history=main_agent.dialogue_history
            )
            rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
            logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
            self.producer.send(rmq_message)
            self.consumer.ack(msg)
        except Exception as e:
            fmt_exc = traceback.format_exc()
            logger.error(f"Error processing message generation: {e}, {fmt_exc}")
            # FIXME: is an ACK actually wanted here?
            self.consumer.ack(msg)
|