|
@@ -3,9 +3,16 @@
|
|
|
# vim:fenc=utf-8
|
|
|
|
|
|
import abc
|
|
|
+import time
|
|
|
+import logging
|
|
|
from typing import Dict, Any, Optional
|
|
|
+import configs
|
|
|
|
|
|
-from message import Message
|
|
|
+import logging_service
|
|
|
+from message import Message, MessageType, MessageChannel
|
|
|
+
|
|
|
+import rocketmq
|
|
|
+from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
|
|
|
|
|
|
|
|
|
class MessageQueueBackend(abc.ABC):
|
|
@@ -13,10 +20,18 @@ class MessageQueueBackend(abc.ABC):
|
|
|
def consume(self) -> Optional[Message]:
|
|
|
pass
|
|
|
|
|
|
+ @abc.abstractmethod
|
|
|
+ def ack(self, message: Message) -> None:
|
|
|
+ pass
|
|
|
+
|
|
|
@abc.abstractmethod
|
|
|
def produce(self, message: Message) -> None:
|
|
|
pass
|
|
|
|
|
|
+ @abc.abstractmethod
|
|
|
+ def shutdown(self):
|
|
|
+ pass
|
|
|
+
|
|
|
class MemoryQueueBackend(MessageQueueBackend):
|
|
|
"""内存消息队列实现"""
|
|
|
def __init__(self):
|
|
@@ -25,5 +40,118 @@ class MemoryQueueBackend(MessageQueueBackend):
|
|
|
def consume(self) -> Optional[Message]:
|
|
|
return self._queue.pop(0) if self._queue else None
|
|
|
|
|
|
+ def ack(self, message: Message):
|
|
|
+ return
|
|
|
+
|
|
|
def produce(self, message: Message):
|
|
|
- self._queue.append(message)
|
|
|
+ self._queue.append(message)
|
|
|
+
|
|
|
+ def shutdown(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class AliyunRocketMQQueueBackend(MessageQueueBackend):
|
|
|
+ def __init__(self, endpoints: str, instance_id: str, topic: str,
|
|
|
+ has_consumer: bool = False, has_producer: bool = False,
|
|
|
+ group_id: Optional[str] = None,
|
|
|
+ ak:Optional[str] = None, sk: Optional[str] = None):
|
|
|
+ if not has_consumer and not has_producer:
|
|
|
+ raise ValueError("At least one of has_consumer or has_producer must be True.")
|
|
|
+ self.has_consumer = has_consumer
|
|
|
+ self.has_producer = has_producer
|
|
|
+ credentials = Credentials()
|
|
|
+ # credentials = Credentials("ak", "sk")
|
|
|
+ mq_config = ClientConfiguration(endpoints, credentials, instance_id)
|
|
|
+ self.topic = topic
|
|
|
+ self.group_id = group_id
|
|
|
+ if has_consumer:
|
|
|
+ self.consumer = SimpleConsumer(mq_config, group_id)
|
|
|
+ self.consumer.startup()
|
|
|
+ self.consumer.subscribe(self.topic)
|
|
|
+ if has_producer:
|
|
|
+ self.producer = rocketmq.Producer(mq_config, (topic,))
|
|
|
+ self.producer.startup()
|
|
|
+
|
|
|
+ def __del__(self):
|
|
|
+ self.shutdown()
|
|
|
+
|
|
|
+ def consume(self) -> Optional[Message]:
|
|
|
+ if not self.has_consumer:
|
|
|
+ raise Exception("Consumer not initialized.")
|
|
|
+ messages = self.consumer.receive(1, 10)
|
|
|
+ if not messages:
|
|
|
+ return None
|
|
|
+ rmq_message = messages[0]
|
|
|
+ body = rmq_message.body.decode('utf-8')
|
|
|
+ logging.debug("recv message body: {}".format(body))
|
|
|
+ try:
|
|
|
+ message = Message.from_json(body)
|
|
|
+ message._rmq_message = rmq_message
|
|
|
+ except Exception as e:
|
|
|
+ logging.error("Invalid message: {}. Parsing error: {}".format(body, e))
|
|
|
+ # 如果消息非法,直接ACK,避免死信
|
|
|
+ self.consumer.ack(rmq_message)
|
|
|
+ return None
|
|
|
+ return message
|
|
|
+
|
|
|
+ def ack(self, message: Message):
|
|
|
+ if not message._rmq_message:
|
|
|
+ raise ValueError("Message not set with _rmq_message.")
|
|
|
+ logging.debug("ack message: {}".format(message))
|
|
|
+ self.consumer.ack(message._rmq_message)
|
|
|
+
|
|
|
+ def produce(self, message: Message) -> None:
|
|
|
+ if not self.has_producer:
|
|
|
+ raise Exception("Producer not initialized.")
|
|
|
+ message.model_config['use_enum_values'] = False
|
|
|
+ json_str = message.to_json()
|
|
|
+ rmq_message = rocketmq.Message()
|
|
|
+ rmq_message.topic = self.topic
|
|
|
+ rmq_message.body = json_str.encode('utf-8')
|
|
|
+ # 顺序消息队列必须指定消息组
|
|
|
+ rmq_message.message_group = "agent_system"
|
|
|
+ self.producer.send(rmq_message)
|
|
|
+
|
|
|
+ def shutdown(self):
|
|
|
+ if self.has_consumer:
|
|
|
+ self.consumer.shutdown()
|
|
|
+ if self.has_producer:
|
|
|
+ self.producer.shutdown()
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ logging_service.setup_root_logger()
|
|
|
+ config = configs.get()
|
|
|
+ # test Aliyun RocketMQ
|
|
|
+ endpoints = config['mq']['endpoints']
|
|
|
+ instance_id = config['mq']['instance_id']
|
|
|
+ topic = config['mq']['receive_topic']
|
|
|
+ group_id = config['mq']['receive_group']
|
|
|
+
|
|
|
+ queue = AliyunRocketMQQueueBackend(
|
|
|
+ endpoints, instance_id, topic, True, True, group_id)
|
|
|
+
|
|
|
+ while True:
|
|
|
+ recv_message = queue.consume()
|
|
|
+ print(recv_message)
|
|
|
+ if recv_message:
|
|
|
+ queue.ack(recv_message)
|
|
|
+ else:
|
|
|
+ break
|
|
|
+
|
|
|
+ send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
|
|
|
+ "user_id_1", "staff_id_0",
|
|
|
+ None, int(time.time() * 1000))
|
|
|
+ queue.produce(send_message)
|
|
|
+ recv_message = queue.consume()
|
|
|
+ print(recv_message)
|
|
|
+ if recv_message:
|
|
|
+ queue.ack(recv_message)
|
|
|
+
|
|
|
+ send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ "user_id_1", "staff_id_0",
|
|
|
+ "message_queue_backend test", int(time.time() * 1000))
|
|
|
+ queue.produce(send_message)
|
|
|
+ recv_message = queue.consume()
|
|
|
+ print(recv_message)
|
|
|
+ if recv_message:
|
|
|
+ queue.ack(recv_message)
|