message_queue_backend.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import abc
  5. import time
  6. from typing import Optional
  7. import rocketmq
  8. from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
  9. from pqai_agent import configs
  10. from pqai_agent.logging import logger
  11. from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
  12. class MessageQueueBackend(abc.ABC):
  13. @abc.abstractmethod
  14. def consume(self) -> Optional[MqMessage]:
  15. pass
  16. @abc.abstractmethod
  17. def ack(self, message: MqMessage) -> None:
  18. pass
  19. @abc.abstractmethod
  20. def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
  21. pass
  22. @abc.abstractmethod
  23. def shutdown(self):
  24. pass
  25. class MemoryQueueBackend(MessageQueueBackend):
  26. """内存消息队列实现"""
  27. def __init__(self):
  28. self._queue = []
  29. def consume(self) -> Optional[MqMessage]:
  30. return self._queue.pop(0) if self._queue else None
  31. def ack(self, message: MqMessage):
  32. return
  33. def produce(self, message: MqMessage, msg_group: Optional[str] = None):
  34. self._queue.append(message)
  35. def shutdown(self):
  36. pass
  37. class AliyunRocketMQQueueBackend(MessageQueueBackend):
  38. def __init__(self, endpoints: str, instance_id: str, topic: str,
  39. has_consumer: bool = False, has_producer: bool = False,
  40. group_id: Optional[str] = None,
  41. ak:Optional[str] = None, sk: Optional[str] = None,
  42. topic_type: Optional[str] = None,
  43. await_duration: int = 20):
  44. if not has_consumer and not has_producer:
  45. raise ValueError("At least one of has_consumer or has_producer must be True.")
  46. self.has_consumer = has_consumer
  47. self.has_producer = has_producer
  48. credentials = Credentials()
  49. # credentials = Credentials("ak", "sk")
  50. mq_config = ClientConfiguration(endpoints, credentials, instance_id)
  51. self.topic = topic
  52. self.group_id = group_id
  53. if has_consumer:
  54. self.consumer = SimpleConsumer(mq_config, group_id, await_duration=await_duration)
  55. self.consumer.startup()
  56. self.consumer.subscribe(self.topic)
  57. if has_producer:
  58. self.producer = rocketmq.Producer(mq_config, (topic,))
  59. self.producer.startup()
  60. self.topic_type = topic_type
  61. def __del__(self):
  62. self.shutdown()
  63. def consume(self, invisible_duration=60) -> Optional[MqMessage]:
  64. if not self.has_consumer:
  65. raise Exception("Consumer not initialized.")
  66. # TODO(zhoutian): invisible_duration实际是不同消息类型不同的,有些消息预期的处理时间会更长
  67. messages = self.consumer.receive(1, invisible_duration)
  68. if not messages:
  69. return None
  70. rmq_message = messages[0]
  71. body = rmq_message.body.decode('utf-8')
  72. logger.debug("[{}]recv message body: {}".format(self.topic, body))
  73. try:
  74. message = MqMessage.from_json(body)
  75. message._rmq_message = rmq_message
  76. except Exception as e:
  77. logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
  78. # 如果消息非法,直接ACK,避免死信
  79. self.consumer.ack(rmq_message)
  80. return None
  81. return message
  82. def ack(self, message: MqMessage):
  83. if not message._rmq_message:
  84. raise ValueError("MqMessage not set with _rmq_message.")
  85. logger.debug("[{}]ack message: {}".format(self.topic, message))
  86. self.consumer.ack(message._rmq_message)
  87. def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
  88. if not self.has_producer:
  89. raise Exception("Producer not initialized.")
  90. message.model_config['use_enum_values'] = False
  91. json_str = message.to_json()
  92. rmq_message = rocketmq.Message()
  93. rmq_message.topic = self.topic
  94. rmq_message.body = json_str.encode('utf-8')
  95. if self.topic_type == 'FIFO':
  96. # 顺序消息队列必须指定消息组
  97. if msg_group is None:
  98. msg_group = f"private:{message.sender}:{message.receiver}"
  99. rmq_message.message_group = msg_group
  100. elif self.topic_type == 'DELAY':
  101. # 延时消息队列必须指定投递时间(秒)
  102. rmq_message.delivery_timestamp = int(message.sendTime / 1000)
  103. self.producer.send(rmq_message)
  104. def shutdown(self):
  105. if self.has_consumer and self.consumer.is_running:
  106. self.consumer.shutdown()
  107. if self.has_producer and self.producer.is_running:
  108. self.producer.shutdown()
  109. if __name__ == '__main__':
  110. logging_service.setup_root_logger()
  111. config = configs.get()
  112. # test Aliyun RocketMQ
  113. endpoints = config['mq']['endpoints']
  114. instance_id = config['mq']['instance_id']
  115. topic = config['mq']['receive_topic']
  116. group_id = config['mq']['receive_group']
  117. queue = AliyunRocketMQQueueBackend(
  118. endpoints, instance_id, topic, True, True, group_id)
  119. while True:
  120. recv_message = queue.consume()
  121. print(recv_message)
  122. if recv_message:
  123. queue.ack(recv_message)
  124. else:
  125. break
  126. send_message = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
  127. "user_id_1", "staff_id_0",
  128. None, int(time.time() * 1000))
  129. queue.produce(send_message)
  130. recv_message = queue.consume()
  131. print(recv_message)
  132. if recv_message:
  133. queue.ack(recv_message)
  134. send_message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  135. "user_id_1", "staff_id_0",
  136. "message_queue_backend test", int(time.time() * 1000))
  137. queue.produce(send_message)
  138. recv_message = queue.consume()
  139. print(recv_message)
  140. if recv_message:
  141. queue.ack(recv_message)