|  | @@ -3,9 +3,16 @@
 | 
	
		
			
				|  |  |  # vim:fenc=utf-8
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import abc
 | 
	
		
			
				|  |  | +import time
 | 
	
		
			
				|  |  | +import logging
 | 
	
		
			
				|  |  |  from typing import Dict, Any, Optional
 | 
	
		
			
				|  |  | +import configs
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from message import Message
 | 
	
		
			
				|  |  | +import logging_service
 | 
	
		
			
				|  |  | +from message import Message, MessageType, MessageChannel
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import rocketmq
 | 
	
		
			
				|  |  | +from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MessageQueueBackend(abc.ABC):
 | 
	
	
		
			
				|  | @@ -13,10 +20,18 @@ class MessageQueueBackend(abc.ABC):
 | 
	
		
			
				|  |  |      def consume(self) -> Optional[Message]:
 | 
	
		
			
				|  |  |          pass
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @abc.abstractmethod
 | 
	
		
			
				|  |  | +    def ack(self, message: Message) -> None:
 | 
	
		
			
				|  |  | +        pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @abc.abstractmethod
 | 
	
		
			
				|  |  |      def produce(self, message: Message) -> None:
 | 
	
		
			
				|  |  |          pass
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @abc.abstractmethod
 | 
	
		
			
				|  |  | +    def shutdown(self):
 | 
	
		
			
				|  |  | +        pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class MemoryQueueBackend(MessageQueueBackend):
 | 
	
		
			
				|  |  |      """内存消息队列实现"""
 | 
	
		
			
				|  |  |      def __init__(self):
 | 
	
	
		
			
				|  | @@ -25,5 +40,118 @@ class MemoryQueueBackend(MessageQueueBackend):
 | 
	
		
			
				|  |  |      def consume(self) -> Optional[Message]:
 | 
	
		
			
				|  |  |          return self._queue.pop(0) if self._queue else None
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def ack(self, message: Message):
 | 
	
		
			
				|  |  | +        return
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def produce(self, message: Message):
 | 
	
		
			
				|  |  | -        self._queue.append(message)
 | 
	
		
			
				|  |  | +        self._queue.append(message)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def shutdown(self):
 | 
	
		
			
				|  |  | +        pass
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class AliyunRocketMQQueueBackend(MessageQueueBackend):
 | 
	
		
			
				|  |  | +    def __init__(self, endpoints: str, instance_id: str, topic: str,
 | 
	
		
			
				|  |  | +                 has_consumer: bool = False, has_producer: bool = False,
 | 
	
		
			
				|  |  | +                 group_id: Optional[str] = None,
 | 
	
		
			
				|  |  | +                 ak:Optional[str] = None, sk: Optional[str] = None):
 | 
	
		
			
				|  |  | +        if not has_consumer and not has_producer:
 | 
	
		
			
				|  |  | +            raise ValueError("At least one of has_consumer or has_producer must be True.")
 | 
	
		
			
				|  |  | +        self.has_consumer = has_consumer
 | 
	
		
			
				|  |  | +        self.has_producer = has_producer
 | 
	
		
			
				|  |  | +        credentials = Credentials()
 | 
	
		
			
				|  |  | +        # credentials = Credentials("ak", "sk")
 | 
	
		
			
				|  |  | +        mq_config = ClientConfiguration(endpoints, credentials, instance_id)
 | 
	
		
			
				|  |  | +        self.topic = topic
 | 
	
		
			
				|  |  | +        self.group_id = group_id
 | 
	
		
			
				|  |  | +        if has_consumer:
 | 
	
		
			
				|  |  | +            self.consumer = SimpleConsumer(mq_config, group_id)
 | 
	
		
			
				|  |  | +            self.consumer.startup()
 | 
	
		
			
				|  |  | +            self.consumer.subscribe(self.topic)
 | 
	
		
			
				|  |  | +        if has_producer:
 | 
	
		
			
				|  |  | +            self.producer = rocketmq.Producer(mq_config, (topic,))
 | 
	
		
			
				|  |  | +            self.producer.startup()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __del__(self):
 | 
	
		
			
				|  |  | +        self.shutdown()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def consume(self) -> Optional[Message]:
 | 
	
		
			
				|  |  | +        if not self.has_consumer:
 | 
	
		
			
				|  |  | +            raise Exception("Consumer not initialized.")
 | 
	
		
			
				|  |  | +        messages = self.consumer.receive(1, 10)
 | 
	
		
			
				|  |  | +        if not messages:
 | 
	
		
			
				|  |  | +            return None
 | 
	
		
			
				|  |  | +        rmq_message = messages[0]
 | 
	
		
			
				|  |  | +        body = rmq_message.body.decode('utf-8')
 | 
	
		
			
				|  |  | +        logging.debug("recv message body: {}".format(body))
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            message = Message.from_json(body)
 | 
	
		
			
				|  |  | +            message._rmq_message = rmq_message
 | 
	
		
			
				|  |  | +        except Exception as e:
 | 
	
		
			
				|  |  | +            logging.error("Invalid message: {}. Parsing error: {}".format(body, e))
 | 
	
		
			
				|  |  | +            # 如果消息非法,直接ACK,避免死信
 | 
	
		
			
				|  |  | +            self.consumer.ack(rmq_message)
 | 
	
		
			
				|  |  | +            return None
 | 
	
		
			
				|  |  | +        return message
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def ack(self, message: Message):
 | 
	
		
			
				|  |  | +        if not message._rmq_message:
 | 
	
		
			
				|  |  | +            raise ValueError("Message not set with _rmq_message.")
 | 
	
		
			
				|  |  | +        logging.debug("ack message: {}".format(message))
 | 
	
		
			
				|  |  | +        self.consumer.ack(message._rmq_message)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def produce(self, message: Message) -> None:
 | 
	
		
			
				|  |  | +        if not self.has_producer:
 | 
	
		
			
				|  |  | +            raise Exception("Producer not initialized.")
 | 
	
		
			
				|  |  | +        message.model_config['use_enum_values'] = False
 | 
	
		
			
				|  |  | +        json_str = message.to_json()
 | 
	
		
			
				|  |  | +        rmq_message = rocketmq.Message()
 | 
	
		
			
				|  |  | +        rmq_message.topic = self.topic
 | 
	
		
			
				|  |  | +        rmq_message.body = json_str.encode('utf-8')
 | 
	
		
			
				|  |  | +        # 顺序消息队列必须指定消息组
 | 
	
		
			
				|  |  | +        rmq_message.message_group = "agent_system"
 | 
	
		
			
				|  |  | +        self.producer.send(rmq_message)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def shutdown(self):
 | 
	
		
			
				|  |  | +        if self.has_consumer:
 | 
	
		
			
				|  |  | +            self.consumer.shutdown()
 | 
	
		
			
				|  |  | +        if self.has_producer:
 | 
	
		
			
				|  |  | +            self.producer.shutdown()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +if __name__ == '__main__':
 | 
	
		
			
				|  |  | +    logging_service.setup_root_logger()
 | 
	
		
			
				|  |  | +    config = configs.get()
 | 
	
		
			
				|  |  | +    # test Aliyun RocketMQ
 | 
	
		
			
				|  |  | +    endpoints = config['mq']['endpoints']
 | 
	
		
			
				|  |  | +    instance_id = config['mq']['instance_id']
 | 
	
		
			
				|  |  | +    topic = config['mq']['receive_topic']
 | 
	
		
			
				|  |  | +    group_id = config['mq']['receive_group']
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    queue = AliyunRocketMQQueueBackend(
 | 
	
		
			
				|  |  | +        endpoints, instance_id, topic, True, True, group_id)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    while True:
 | 
	
		
			
				|  |  | +        recv_message = queue.consume()
 | 
	
		
			
				|  |  | +        print(recv_message)
 | 
	
		
			
				|  |  | +        if recv_message:
 | 
	
		
			
				|  |  | +            queue.ack(recv_message)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            break
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
 | 
	
		
			
				|  |  | +                                 "user_id_1", "staff_id_0",
 | 
	
		
			
				|  |  | +                                 None, int(time.time() * 1000))
 | 
	
		
			
				|  |  | +    queue.produce(send_message)
 | 
	
		
			
				|  |  | +    recv_message = queue.consume()
 | 
	
		
			
				|  |  | +    print(recv_message)
 | 
	
		
			
				|  |  | +    if recv_message:
 | 
	
		
			
				|  |  | +        queue.ack(recv_message)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
 | 
	
		
			
				|  |  | +                                 "user_id_1", "staff_id_0",
 | 
	
		
			
				|  |  | +                                 "message_queue_backend test", int(time.time() * 1000))
 | 
	
		
			
				|  |  | +    queue.produce(send_message)
 | 
	
		
			
				|  |  | +    recv_message = queue.consume()
 | 
	
		
			
				|  |  | +    print(recv_message)
 | 
	
		
			
				|  |  | +    if recv_message:
 | 
	
		
			
				|  |  | +        queue.ack(recv_message)
 |