123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
import abc
import time
from collections import deque
from typing import Any, Dict, Optional

import rocketmq
from rocketmq import ClientConfiguration, Credentials, SimpleConsumer

import configs
import logging_service
from logging_service import logger
from message import Message, MessageChannel, MessageType
class MessageQueueBackend(abc.ABC):
    """Common interface for message queue backends.

    Concrete backends must implement consuming, acknowledging, producing
    and shutting down.
    """

    @abc.abstractmethod
    def consume(self) -> Optional[Message]:
        """Return the next available message, or None if there is none to hand out."""
        ...

    @abc.abstractmethod
    def ack(self, message: Message) -> None:
        """Acknowledge a message previously returned by consume()."""
        ...

    @abc.abstractmethod
    def produce(self, message: Message) -> None:
        """Publish a message to the queue."""
        ...

    @abc.abstractmethod
    def shutdown(self):
        """Release any resources held by the backend."""
        ...
class MemoryQueueBackend(MessageQueueBackend):
    """In-memory FIFO message queue backend.

    Messages live only in this object's internal queue. Acknowledgement is a
    no-op because consume() already removes the message from the queue.
    """

    def __init__(self):
        # deque pops from the left end in O(1); list.pop(0) shifts every
        # remaining element and is O(n) per call.
        self._queue = deque()

    def consume(self) -> Optional[Message]:
        """Remove and return the oldest message, or None if the queue is empty."""
        return self._queue.popleft() if self._queue else None

    def ack(self, message: Message) -> None:
        """No-op: consume() has already removed the message."""
        return

    def produce(self, message: Message) -> None:
        """Append a message to the tail of the queue."""
        self._queue.append(message)

    def shutdown(self):
        """No-op: this backend holds no external resources."""
        pass
class AliyunRocketMQQueueBackend(MessageQueueBackend):
    """Message queue backend on (Aliyun) RocketMQ via the ``rocketmq`` client.

    One instance works on a single topic and can act as a consumer, a
    producer, or both. Call ``shutdown()`` when finished. ``__del__`` calls it
    as a fallback, and repeated calls are harmless.
    """

    # Client handles. They stay None when the corresponding role was not
    # requested, and are reset to None by shutdown(). Class-level defaults
    # keep shutdown()/__del__ safe even if __init__ raised before assigning them.
    consumer = None
    producer = None

    def __init__(self, endpoints: str, instance_id: str, topic: str,
                 has_consumer: bool = False, has_producer: bool = False,
                 group_id: Optional[str] = None,
                 ak: Optional[str] = None, sk: Optional[str] = None,
                 message_group: str = "agent_system"):
        """Create and start the requested RocketMQ client(s).

        Args:
            endpoints: Endpoint string passed to ``ClientConfiguration``.
            instance_id: Passed to ``ClientConfiguration`` as its third
                positional argument.
            topic: Topic the consumer subscribes to and the producer sends to.
            has_consumer: Create, start and subscribe a ``SimpleConsumer``.
            has_producer: Create and start a ``rocketmq.Producer``.
            group_id: Consumer group; required when ``has_consumer`` is True.
            ak: Access key passed to ``Credentials``; must be given together
                with ``sk``. When both are omitted, default ``Credentials()``
                are used.
            sk: Secret key passed to ``Credentials``; must be given together
                with ``ak``.
            message_group: Message group stamped on every produced message.

        Raises:
            ValueError: If neither role is requested, ``group_id`` is missing
                for a consumer, or only one of ``ak``/``sk`` is given.
        """
        if not has_consumer and not has_producer:
            raise ValueError("At least one of has_consumer or has_producer must be True.")
        if has_consumer and not group_id:
            raise ValueError("group_id is required when has_consumer is True.")
        if (ak is None) != (sk is None):
            raise ValueError("ak and sk must be provided together.")
        self.has_consumer = has_consumer
        self.has_producer = has_producer
        self.topic = topic
        self.group_id = group_id
        self.message_group = message_group

        # Previously ak/sk were accepted but silently ignored.
        credentials = Credentials(ak, sk) if ak is not None else Credentials()
        mq_config = ClientConfiguration(endpoints, credentials, instance_id)

        if has_consumer:
            self.consumer = SimpleConsumer(mq_config, group_id)
            self.consumer.startup()
            self.consumer.subscribe(self.topic)
        if has_producer:
            self.producer = rocketmq.Producer(mq_config, (topic,))
            self.producer.startup()

    def __del__(self):
        # Fallback cleanup; shutdown() is idempotent, so an earlier explicit
        # call is not repeated against the clients.
        self.shutdown()

    def _require_consumer(self):
        """Return the live consumer, or raise if this backend has none (or was shut down)."""
        if self.consumer is None:
            raise RuntimeError("Consumer not initialized.")
        return self.consumer

    def _require_producer(self):
        """Return the live producer, or raise if this backend has none (or was shut down)."""
        if self.producer is None:
            raise RuntimeError("Producer not initialized.")
        return self.producer

    def _discard_invalid(self, rmq_message, payload, error) -> None:
        """Log an unparseable message and ack it.

        Acking immediately keeps the broker from redelivering it over and
        over, which would eventually send it to the dead-letter queue.
        """
        logger.error("Invalid message: {}. Parsing error: {}".format(payload, error))
        self._require_consumer().ack(rmq_message)

    def consume(self, invisible_duration=60) -> Optional[Message]:
        """Receive at most one message and parse it into a ``Message``.

        Args:
            invisible_duration: Passed to ``SimpleConsumer.receive`` as the
                invisibility time of the received message.

        Returns:
            The parsed message, with the raw RocketMQ message attached as
            ``_rmq_message`` for a later ``ack()``. Returns None if nothing
            was received or the message was invalid; invalid messages are
            logged and acked.

        Raises:
            RuntimeError: If this backend has no consumer.
        """
        consumer = self._require_consumer()
        # TODO(zhoutian): invisible_duration really depends on the message
        # type; some messages are expected to take longer to process.
        messages = consumer.receive(1, invisible_duration)
        if not messages:
            return None
        rmq_message = messages[0]
        # Decoding happens inside error handling. Otherwise a non-UTF-8 body
        # would raise here, never be acked, and be redelivered forever.
        try:
            body = rmq_message.body.decode('utf-8')
        except UnicodeDecodeError as e:
            self._discard_invalid(rmq_message, rmq_message.body, e)
            return None
        logger.debug("recv message body: {}".format(body))
        try:
            message = Message.from_json(body)
            message._rmq_message = rmq_message
        except Exception as e:
            self._discard_invalid(rmq_message, body, e)
            return None
        return message

    def ack(self, message: Message):
        """Acknowledge a message previously returned by ``consume()``.

        Raises:
            ValueError: If the message carries no ``_rmq_message``.
            RuntimeError: If this backend has no consumer.
        """
        rmq_message = getattr(message, "_rmq_message", None)
        if not rmq_message:
            raise ValueError("Message not set with _rmq_message.")
        consumer = self._require_consumer()
        logger.debug("ack message: {}".format(message))
        consumer.ack(rmq_message)

    def produce(self, message: Message) -> None:
        """Serialize ``message`` to UTF-8 JSON and send it to ``self.topic``.

        Raises:
            RuntimeError: If this backend has no producer.
        """
        producer = self._require_producer()
        # NOTE(review): if Message is a pydantic model, model_config is shared
        # at class level. This then changes the setting for every Message,
        # and may not affect an already-built serializer. Confirm that
        # Message.to_json() reads this key.
        message.model_config['use_enum_values'] = False
        json_str = message.to_json()
        rmq_message = rocketmq.Message()
        rmq_message.topic = self.topic
        rmq_message.body = json_str.encode('utf-8')
        # Ordered (FIFO) queues require every message to carry a message group.
        rmq_message.message_group = self.message_group
        producer.send(rmq_message)

    def shutdown(self):
        """Shut down the consumer and/or producer. Safe to call more than once."""
        consumer, producer = self.consumer, self.producer
        # Clear the handles first so a second call (e.g. __del__ after an
        # explicit shutdown()) does not shut the clients down again.
        self.consumer = None
        self.producer = None
        try:
            if consumer is not None:
                consumer.shutdown()
        finally:
            # Still stop the producer even if the consumer's shutdown failed.
            if producer is not None:
                producer.shutdown()
if __name__ == '__main__':
    # Manual smoke test against Aliyun RocketMQ: one backend that both
    # consumes from and produces to the configured receive topic.
    logging_service.setup_root_logger()
    mq_settings = configs.get()['mq']
    backend = AliyunRocketMQQueueBackend(
        mq_settings['endpoints'],
        mq_settings['instance_id'],
        mq_settings['receive_topic'],
        has_consumer=True,
        has_producer=True,
        group_id=mq_settings['receive_group'],
    )

    # Drain whatever is already queued, stopping at the first None from consume().
    while True:
        received = backend.consume()
        print(received)
        if not received:
            break
        backend.ack(received)

    # Round trip: send one message of each kind, then read one back and ack it.
    round_trips = (
        (MessageType.AGGREGATION_TRIGGER, None),
        (MessageType.TEXT, "message_queue_backend test"),
    )
    for message_type, content in round_trips:
        outgoing = Message.build(message_type, MessageChannel.CORP_WECHAT,
                                 "user_id_1", "staff_id_0",
                                 content, int(time.time() * 1000))
        backend.produce(outgoing)
        received = backend.consume()
        print(received)
        if received:
            backend.ack(received)
|