Selaa lähdekoodia

Update message_queue_backend: support aliyun RMQ

StrayWarrior 2 viikkoa sitten
vanhempi
commit
237e19439a
2 muutettua tiedostoa jossa 144 lisäystä ja 2 poistoa
  1. 14 0
      configs/__init__.py
  2. 130 2
      message_queue_backend.py

+ 14 - 0
configs/__init__.py

@@ -0,0 +1,14 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# vim:fenc=utf-8
+
+
+import os
+import yaml
+
+def get():
+    dirname = os.path.dirname(os.path.abspath(__file__))
+    env = os.environ.get('AI_AGENT_ENV', 'dev')
+    if env not in ('dev', 'pre', 'prod'):
+        raise ValueError(f"Invalid environment: {env}. Expected one of ('dev', 'pre', 'prod').")
+    return yaml.safe_load(open(f'{dirname}/{env}.yaml').read())

+ 130 - 2
message_queue_backend.py

@@ -3,9 +3,16 @@
 # vim:fenc=utf-8
 
 import abc
+import time
+import logging
 from typing import Dict, Any, Optional
+import configs
 
-from message import Message
+import logging_service
+from message import Message, MessageType, MessageChannel
+
+import rocketmq
+from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 
 
 class MessageQueueBackend(abc.ABC):
@@ -13,10 +20,18 @@ class MessageQueueBackend(abc.ABC):
     def consume(self) -> Optional[Message]:
         pass
 
+    @abc.abstractmethod
+    def ack(self, message: Message) -> None:
+        pass
+
     @abc.abstractmethod
     def produce(self, message: Message) -> None:
         pass
 
+    @abc.abstractmethod
+    def shutdown(self):
+        pass
+
 class MemoryQueueBackend(MessageQueueBackend):
     """内存消息队列实现"""
     def __init__(self):
@@ -25,5 +40,118 @@ class MemoryQueueBackend(MessageQueueBackend):
     def consume(self) -> Optional[Message]:
         return self._queue.pop(0) if self._queue else None
 
+    def ack(self, message: Message):
+        return
+
     def produce(self, message: Message):
-        self._queue.append(message)
+        self._queue.append(message)
+
+    def shutdown(self):
+        pass
+
+
+class AliyunRocketMQQueueBackend(MessageQueueBackend):
+    def __init__(self, endpoints: str, instance_id: str, topic: str,
+                 has_consumer: bool = False, has_producer: bool = False,
+                 group_id: Optional[str] = None,
+                 ak:Optional[str] = None, sk: Optional[str] = None):
+        if not has_consumer and not has_producer:
+            raise ValueError("At least one of has_consumer or has_producer must be True.")
+        self.has_consumer = has_consumer
+        self.has_producer = has_producer
+        credentials = Credentials()
+        # credentials = Credentials("ak", "sk")
+        mq_config = ClientConfiguration(endpoints, credentials, instance_id)
+        self.topic = topic
+        self.group_id = group_id
+        if has_consumer:
+            self.consumer = SimpleConsumer(mq_config, group_id)
+            self.consumer.startup()
+            self.consumer.subscribe(self.topic)
+        if has_producer:
+            self.producer = rocketmq.Producer(mq_config, (topic,))
+            self.producer.startup()
+
+    def __del__(self):
+        self.shutdown()
+
+    def consume(self) -> Optional[Message]:
+        if not self.has_consumer:
+            raise Exception("Consumer not initialized.")
+        messages = self.consumer.receive(1, 10)
+        if not messages:
+            return None
+        rmq_message = messages[0]
+        body = rmq_message.body.decode('utf-8')
+        logging.debug("recv message body: {}".format(body))
+        try:
+            message = Message.from_json(body)
+            message._rmq_message = rmq_message
+        except Exception as e:
+            logging.error("Invalid message: {}. Parsing error: {}".format(body, e))
+            # 如果消息非法,直接ACK,避免死信
+            self.consumer.ack(rmq_message)
+            return None
+        return message
+
+    def ack(self, message: Message):
+        if not message._rmq_message:
+            raise ValueError("Message not set with _rmq_message.")
+        logging.debug("ack message: {}".format(message))
+        self.consumer.ack(message._rmq_message)
+
+    def produce(self, message: Message) -> None:
+        if not self.has_producer:
+            raise Exception("Producer not initialized.")
+        message.model_config['use_enum_values'] = False
+        json_str = message.to_json()
+        rmq_message = rocketmq.Message()
+        rmq_message.topic = self.topic
+        rmq_message.body = json_str.encode('utf-8')
+        # 顺序消息队列必须指定消息组
+        rmq_message.message_group = "agent_system"
+        self.producer.send(rmq_message)
+
+    def shutdown(self):
+        if self.has_consumer:
+            self.consumer.shutdown()
+        if self.has_producer:
+            self.producer.shutdown()
+
+if __name__ == '__main__':
+    logging_service.setup_root_logger()
+    config = configs.get()
+    # test Aliyun RocketMQ
+    endpoints = config['mq']['endpoints']
+    instance_id = config['mq']['instance_id']
+    topic = config['mq']['receive_topic']
+    group_id = config['mq']['receive_group']
+
+    queue = AliyunRocketMQQueueBackend(
+        endpoints, instance_id, topic, True, True, group_id)
+
+    while True:
+        recv_message = queue.consume()
+        print(recv_message)
+        if recv_message:
+            queue.ack(recv_message)
+        else:
+            break
+
+    send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
+                                 "user_id_1", "staff_id_0",
+                                 None, int(time.time() * 1000))
+    queue.produce(send_message)
+    recv_message = queue.consume()
+    print(recv_message)
+    if recv_message:
+        queue.ack(recv_message)
+
+    send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                 "user_id_1", "staff_id_0",
+                                 "message_queue_backend test", int(time.time() * 1000))
+    queue.produce(send_message)
+    recv_message = queue.consume()
+    print(recv_message)
+    if recv_message:
+        queue.ack(recv_message)