message_queue_backend.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import abc
  5. import time
  6. from logging_service import logger
  7. from typing import Dict, Any, Optional
  8. import configs
  9. import logging_service
  10. from message import Message, MessageType, MessageChannel
  11. import rocketmq
  12. from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
  13. class MessageQueueBackend(abc.ABC):
  14. @abc.abstractmethod
  15. def consume(self) -> Optional[Message]:
  16. pass
  17. @abc.abstractmethod
  18. def ack(self, message: Message) -> None:
  19. pass
  20. @abc.abstractmethod
  21. def produce(self, message: Message) -> None:
  22. pass
  23. @abc.abstractmethod
  24. def shutdown(self):
  25. pass
  26. class MemoryQueueBackend(MessageQueueBackend):
  27. """内存消息队列实现"""
  28. def __init__(self):
  29. self._queue = []
  30. def consume(self) -> Optional[Message]:
  31. return self._queue.pop(0) if self._queue else None
  32. def ack(self, message: Message):
  33. return
  34. def produce(self, message: Message):
  35. self._queue.append(message)
  36. def shutdown(self):
  37. pass
  38. class AliyunRocketMQQueueBackend(MessageQueueBackend):
  39. def __init__(self, endpoints: str, instance_id: str, topic: str,
  40. has_consumer: bool = False, has_producer: bool = False,
  41. group_id: Optional[str] = None,
  42. ak:Optional[str] = None, sk: Optional[str] = None):
  43. if not has_consumer and not has_producer:
  44. raise ValueError("At least one of has_consumer or has_producer must be True.")
  45. self.has_consumer = has_consumer
  46. self.has_producer = has_producer
  47. credentials = Credentials()
  48. # credentials = Credentials("ak", "sk")
  49. mq_config = ClientConfiguration(endpoints, credentials, instance_id)
  50. self.topic = topic
  51. self.group_id = group_id
  52. if has_consumer:
  53. self.consumer = SimpleConsumer(mq_config, group_id)
  54. self.consumer.startup()
  55. self.consumer.subscribe(self.topic)
  56. if has_producer:
  57. self.producer = rocketmq.Producer(mq_config, (topic,))
  58. self.producer.startup()
  59. def __del__(self):
  60. self.shutdown()
  61. def consume(self, invisible_duration=60) -> Optional[Message]:
  62. if not self.has_consumer:
  63. raise Exception("Consumer not initialized.")
  64. # TODO(zhoutian): invisible_duration实际是不同消息类型不同的,有些消息预期的处理时间会更长
  65. messages = self.consumer.receive(1, invisible_duration)
  66. if not messages:
  67. return None
  68. rmq_message = messages[0]
  69. body = rmq_message.body.decode('utf-8')
  70. logger.debug("recv message body: {}".format(body))
  71. try:
  72. message = Message.from_json(body)
  73. message._rmq_message = rmq_message
  74. except Exception as e:
  75. logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
  76. # 如果消息非法,直接ACK,避免死信
  77. self.consumer.ack(rmq_message)
  78. return None
  79. return message
  80. def ack(self, message: Message):
  81. if not message._rmq_message:
  82. raise ValueError("Message not set with _rmq_message.")
  83. logger.debug("ack message: {}".format(message))
  84. self.consumer.ack(message._rmq_message)
  85. def produce(self, message: Message) -> None:
  86. if not self.has_producer:
  87. raise Exception("Producer not initialized.")
  88. message.model_config['use_enum_values'] = False
  89. json_str = message.to_json()
  90. rmq_message = rocketmq.Message()
  91. rmq_message.topic = self.topic
  92. rmq_message.body = json_str.encode('utf-8')
  93. # 顺序消息队列必须指定消息组
  94. rmq_message.message_group = "agent_system"
  95. self.producer.send(rmq_message)
  96. def shutdown(self):
  97. if self.has_consumer:
  98. self.consumer.shutdown()
  99. if self.has_producer:
  100. self.producer.shutdown()
  101. if __name__ == '__main__':
  102. logging_service.setup_root_logger()
  103. config = configs.get()
  104. # test Aliyun RocketMQ
  105. endpoints = config['mq']['endpoints']
  106. instance_id = config['mq']['instance_id']
  107. topic = config['mq']['receive_topic']
  108. group_id = config['mq']['receive_group']
  109. queue = AliyunRocketMQQueueBackend(
  110. endpoints, instance_id, topic, True, True, group_id)
  111. while True:
  112. recv_message = queue.consume()
  113. print(recv_message)
  114. if recv_message:
  115. queue.ack(recv_message)
  116. else:
  117. break
  118. send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
  119. "user_id_1", "staff_id_0",
  120. None, int(time.time() * 1000))
  121. queue.produce(send_message)
  122. recv_message = queue.consume()
  123. print(recv_message)
  124. if recv_message:
  125. queue.ack(recv_message)
  126. send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  127. "user_id_1", "staff_id_0",
  128. "message_queue_backend test", int(time.time() * 1000))
  129. queue.produce(send_message)
  130. recv_message = queue.consume()
  131. print(recv_message)
  132. if recv_message:
  133. queue.ack(recv_message)