#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import abc
import time
from collections import deque
from typing import Optional

import rocketmq
from rocketmq import ClientConfiguration, Credentials, SimpleConsumer

from pqai_agent import configs
from pqai_agent import logging_service
from pqai_agent.logging_service import logger
from pqai_agent.message import Message, MessageType, MessageChannel
class MessageQueueBackend(abc.ABC):
    """Abstract interface shared by all message-queue backends.

    Concrete backends must implement every method below; the ABC machinery
    prevents instantiating a subclass that leaves any of them undefined.
    """

    @abc.abstractmethod
    def consume(self) -> Optional[Message]:
        """Return the next available message, or ``None`` when there is none."""

    @abc.abstractmethod
    def ack(self, message: Message) -> None:
        """Acknowledge a message previously returned by ``consume``."""

    @abc.abstractmethod
    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
        """Publish ``message``; ``msg_group`` is an optional backend-specific grouping key."""

    @abc.abstractmethod
    def shutdown(self):
        """Release whatever resources the backend holds."""
class MemoryQueueBackend(MessageQueueBackend):
    """In-memory message queue implementation.

    Messages are kept in this object only and are returned in the order they
    were produced. Acknowledgement and shutdown are no-ops, and ``msg_group``
    is accepted for interface compatibility but ignored.
    """

    def __init__(self):
        # deque pops from the left in O(1); list.pop(0) shifts every element (O(n)).
        self._queue = deque()

    def consume(self) -> Optional[Message]:
        """Remove and return the oldest message, or ``None`` if the queue is empty."""
        return self._queue.popleft() if self._queue else None

    def ack(self, message: Message):
        """No-op: a message leaves the queue as soon as it is consumed."""
        return

    def produce(self, message: Message, msg_group: Optional[str] = None):
        """Append ``message`` to the tail of the queue; ``msg_group`` is ignored."""
        self._queue.append(message)

    def shutdown(self):
        """No-op: there are no external resources to release."""
class AliyunRocketMQQueueBackend(MessageQueueBackend):
    """Message queue backend on top of an (Aliyun) RocketMQ topic.

    One instance is bound to a single topic and can act as a consumer
    (``SimpleConsumer``), a producer, or both. Messages are carried as the
    UTF-8 encoded output of ``Message.to_json`` and parsed back with
    ``Message.from_json``.
    """

    def __init__(self, endpoints: str, instance_id: str, topic: str,
                 has_consumer: bool = False, has_producer: bool = False,
                 group_id: Optional[str] = None,
                 ak: Optional[str] = None, sk: Optional[str] = None,
                 topic_type: Optional[str] = None,
                 await_duration: int = 20):
        """Create and start the RocketMQ client(s).

        Args:
            endpoints: RocketMQ endpoints passed to ``ClientConfiguration``.
            instance_id: RocketMQ instance id.
            topic: Topic this backend subscribes to and/or publishes to.
            has_consumer: Start a ``SimpleConsumer`` subscribed to ``topic``.
            has_producer: Start a ``rocketmq.Producer`` for ``topic``.
            group_id: Consumer group id, used only when ``has_consumer``.
            ak: Access key; used only when ``sk`` is also provided.
            sk: Secret key; used only when ``ak`` is also provided.
            topic_type: ``'FIFO'`` or ``'DELAY'`` selects the matching
                produce-side handling; any other value sends plain messages.
            await_duration: Passed to ``SimpleConsumer`` as ``await_duration``.

        Raises:
            ValueError: If neither ``has_consumer`` nor ``has_producer`` is True.
        """
        # Define the client slots before anything can fail, so that
        # shutdown()/__del__ work on a partially-constructed instance.
        self.consumer = None
        self.producer = None
        if not has_consumer and not has_producer:
            raise ValueError("At least one of has_consumer or has_producer must be True.")
        self.has_consumer = has_consumer
        self.has_producer = has_producer
        self.topic = topic
        self.group_id = group_id
        self.topic_type = topic_type

        # ak/sk used to be accepted but silently ignored; honor them when given.
        if ak and sk:
            credentials = Credentials(ak, sk)
        else:
            if ak or sk:
                logger.warning("[{}]only one of ak/sk provided, connecting without credentials".format(topic))
            credentials = Credentials()
        mq_config = ClientConfiguration(endpoints, credentials, instance_id)

        try:
            if has_consumer:
                self.consumer = SimpleConsumer(mq_config, group_id, await_duration=await_duration)
                self.consumer.startup()
                self.consumer.subscribe(self.topic)
            if has_producer:
                self.producer = rocketmq.Producer(mq_config, (topic,))
                self.producer.startup()
        except Exception:
            # Don't leave an already-started client running if a later step fails.
            self.shutdown()
            raise

    def __del__(self):
        # Best-effort cleanup; shutdown() tolerates missing or stopped clients.
        self.shutdown()

    def consume(self, invisible_duration=60) -> Optional[Message]:
        """Receive at most one message from the subscribed topic.

        Args:
            invisible_duration: Passed to ``SimpleConsumer.receive`` as the
                invisible duration of the received message (presumably in
                seconds; confirm against the RocketMQ client).

        Returns:
            The parsed ``Message`` with its raw RocketMQ message attached as
            ``_rmq_message`` (needed by ``ack``), or ``None`` if nothing was
            received or the body could not be parsed.

        Raises:
            RuntimeError: If this backend was created without a consumer.
        """
        if not self.has_consumer:
            raise RuntimeError("Consumer not initialized.")
        # TODO(zhoutian): invisible_duration should really depend on the message
        # type; some messages are expected to take longer to process.
        messages = self.consumer.receive(1, invisible_duration)
        if not messages:
            return None
        rmq_message = messages[0]
        body = rmq_message.body.decode('utf-8')
        logger.debug("[{}]recv message body: {}".format(self.topic, body))
        try:
            message = Message.from_json(body)
        except Exception as e:
            logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
            # Ack invalid messages immediately so they don't end up as dead letters.
            self.consumer.ack(rmq_message)
            return None
        # Kept outside the try: an internal failure here must not cause a
        # valid message to be acked and silently dropped.
        message._rmq_message = rmq_message
        return message

    def ack(self, message: Message):
        """Acknowledge a message previously returned by ``consume``.

        Raises:
            RuntimeError: If this backend was created without a consumer.
            ValueError: If ``message`` carries no ``_rmq_message``.
        """
        if not self.has_consumer:
            raise RuntimeError("Consumer not initialized.")
        rmq_message = getattr(message, '_rmq_message', None)
        if not rmq_message:
            raise ValueError("Message not set with _rmq_message.")
        logger.debug("[{}]ack message: {}".format(self.topic, message))
        self.consumer.ack(rmq_message)

    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
        """Serialize ``message`` to JSON and send it to the topic.

        Args:
            message: Message to send.
            msg_group: Message group for ``'FIFO'`` topics; defaults to
                ``"private:{sender}:{receiver}"``. Ignored for other topic types.

        Raises:
            RuntimeError: If this backend was created without a producer.
        """
        if not self.has_producer:
            raise RuntimeError("Producer not initialized.")
        # NOTE(review): model_config is class-level in pydantic, so this changes
        # the setting for every Message, not just this one — confirm it is needed.
        message.model_config['use_enum_values'] = False
        rmq_message = rocketmq.Message()
        rmq_message.topic = self.topic
        rmq_message.body = message.to_json().encode('utf-8')
        if self.topic_type == 'FIFO':
            # FIFO (ordered) topics require a message group.
            if msg_group is None:
                msg_group = f"private:{message.sender}:{message.receiver}"
            rmq_message.message_group = msg_group
        elif self.topic_type == 'DELAY':
            # Delay topics require a delivery time in seconds; sendTime appears
            # to be in milliseconds. Integer division avoids a float round-trip.
            rmq_message.delivery_timestamp = int(message.sendTime // 1000)
        self.producer.send(rmq_message)

    def shutdown(self):
        """Stop the consumer and/or producer if they exist and are running.

        Clients that were never created (or failed during construction) are
        skipped, so this is safe to call from ``__del__``.
        """
        for client in (getattr(self, 'consumer', None), getattr(self, 'producer', None)):
            if client is not None and client.is_running:
                client.shutdown()
if __name__ == '__main__':
    # Manual smoke test against the configured Aliyun RocketMQ receive topic,
    # acting as both producer and consumer on that topic.
    logging_service.setup_root_logger()
    config = configs.get()
    mq_conf = config['mq']
    mq = AliyunRocketMQQueueBackend(
        mq_conf['endpoints'], mq_conf['instance_id'], mq_conf['receive_topic'],
        has_consumer=True, has_producer=True, group_id=mq_conf['receive_group'])

    def _receive_one():
        """Consume one message, print it, ack it if present, and return it."""
        received = mq.consume()
        print(received)
        if received:
            mq.ack(received)
        return received

    try:
        # Drain anything left over from earlier runs.
        while _receive_one():
            pass
        # Round-trip an aggregation trigger (no content) and then a text message.
        for msg_type, content in ((MessageType.AGGREGATION_TRIGGER, None),
                                  (MessageType.TEXT, "message_queue_backend test")):
            send_message = Message.build(msg_type, MessageChannel.CORP_WECHAT,
                                         "user_id_1", "staff_id_0",
                                         content, int(time.time() * 1000))
            mq.produce(send_message)
            _receive_one()
    finally:
        # Release the RocketMQ clients deterministically rather than relying on __del__.
        mq.shutdown()