#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

import abc
import time
from logging_service import logger
from typing import Dict, Any, Optional
import configs

import logging_service
from message import Message, MessageType, MessageChannel

import rocketmq
from rocketmq import ClientConfiguration, Credentials, SimpleConsumer


class MessageQueueBackend(abc.ABC):
    """Abstract interface that every message queue backend implements."""

    @abc.abstractmethod
    def consume(self) -> Optional[Message]:
        """Fetch the next message, or return None when no message is available."""

    @abc.abstractmethod
    def ack(self, message: Message) -> None:
        """Acknowledge a message previously returned by consume()."""

    @abc.abstractmethod
    def produce(self, message: Message) -> None:
        """Publish a message to the queue."""

    @abc.abstractmethod
    def shutdown(self):
        """Release whatever resources the backend holds."""

class MemoryQueueBackend(MessageQueueBackend):
    """In-memory FIFO message queue implementation.

    Messages are kept only in this object's memory. consume() removes a
    message immediately, so there is no redelivery and ack() is a no-op.
    """
    def __init__(self):
        # deque.popleft() is O(1); the previous list.pop(0) was O(n) per consume.
        from collections import deque
        self._queue = deque()

    def consume(self) -> Optional[Message]:
        """Remove and return the oldest message, or None if the queue is empty."""
        return self._queue.popleft() if self._queue else None

    def ack(self, message: Message):
        """No-op: the message was already removed from the queue by consume()."""
        return

    def produce(self, message: Message):
        """Append a message to the tail of the queue."""
        self._queue.append(message)

    def shutdown(self):
        """Nothing to release for an in-memory queue."""
        pass


class AliyunRocketMQQueueBackend(MessageQueueBackend):
    """Message queue backend built on Aliyun RocketMQ via the `rocketmq` client.

    One instance can act as a consumer (a `SimpleConsumer` subscribed to
    `topic`), a producer for `topic`, or both.
    """

    def __init__(self, endpoints: str, instance_id: str, topic: str,
                 has_consumer: bool = False, has_producer: bool = False,
                 group_id: Optional[str] = None,
                 ak: Optional[str] = None, sk: Optional[str] = None,
                 message_group: str = "agent_system"):
        """Create and start the requested RocketMQ client(s).

        Args:
            endpoints: Endpoint string passed to `ClientConfiguration`.
            instance_id: RocketMQ instance id passed to `ClientConfiguration`.
            topic: Topic to subscribe to and/or publish to.
            has_consumer: Create, start and subscribe a `SimpleConsumer`.
            has_producer: Create and start a producer for `topic`.
            group_id: Consumer group passed to `SimpleConsumer`.
            ak: Access key. Credentials are only used when both `ak` and `sk`
                are given; otherwise empty credentials are used.
            sk: Secret key (see `ak`).
            message_group: Message group set on every produced message
                (ordered topics require one).

        Raises:
            ValueError: If neither `has_consumer` nor `has_producer` is True.
        """
        # Set state first: __del__ -> shutdown() also runs on an object whose
        # __init__ raised, and must not fail with AttributeError there.
        self.has_consumer = False
        self.has_producer = False
        self._is_shutdown = False
        if not has_consumer and not has_producer:
            raise ValueError("At least one of has_consumer or has_producer must be True.")
        # Previously ak/sk were accepted but silently ignored.
        credentials = Credentials(ak, sk) if ak and sk else Credentials()
        mq_config = ClientConfiguration(endpoints, credentials, instance_id)
        self.topic = topic
        self.group_id = group_id
        self.message_group = message_group
        if has_consumer:
            self.consumer = SimpleConsumer(mq_config, group_id)
            # Mark as present as soon as the object exists so shutdown() reaches it.
            self.has_consumer = True
            self.consumer.startup()
            self.consumer.subscribe(self.topic)
        if has_producer:
            self.producer = rocketmq.Producer(mq_config, (topic,))
            self.has_producer = True
            self.producer.startup()

    def __del__(self):
        # Fallback cleanup; callers should prefer calling shutdown() explicitly.
        self.shutdown()

    def consume(self, invisible_duration=60) -> Optional[Message]:
        """Receive at most one message from the subscribed topic.

        Args:
            invisible_duration: Passed to `SimpleConsumer.receive`. NOTE(review):
                presumably the invisibility timeout in seconds - confirm with the SDK.

        Returns:
            The parsed Message, with the raw RocketMQ message attached as
            `_rmq_message` for a later ack(). Returns None if nothing was received,
            or if the message was not UTF-8 or could not be parsed. Such invalid
            messages are acked immediately so they do not end up as dead letters.

        Raises:
            Exception: If this backend was created without a consumer.
        """
        if not self.has_consumer:
            raise Exception("Consumer not initialized.")
        # TODO(zhoutian): invisible_duration actually differs per message type;
        # some messages are expected to take longer to process.
        messages = self.consumer.receive(1, invisible_duration)
        if not messages:
            return None
        rmq_message = messages[0]
        try:
            body = rmq_message.body.decode('utf-8')
        except UnicodeDecodeError as e:
            # A body that is not UTF-8 is as invalid as unparsable JSON: ack it
            # instead of letting it be redelivered.
            logger.error("Invalid message: {}. Parsing error: {}".format(rmq_message.body, e))
            self.consumer.ack(rmq_message)
            return None
        logger.debug("recv message body: {}".format(body))
        try:
            message = Message.from_json(body)
        except Exception as e:
            logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
            # Invalid message: ack it right away to avoid a dead letter.
            self.consumer.ack(rmq_message)
            return None
        message._rmq_message = rmq_message
        return message

    def ack(self, message: Message):
        """Acknowledge a message previously returned by consume().

        Raises:
            Exception: If this backend was created without a consumer.
            ValueError: If the message carries no `_rmq_message` (i.e. it did
                not come from consume()).
        """
        if not self.has_consumer:
            raise Exception("Consumer not initialized.")
        rmq_message = getattr(message, '_rmq_message', None)
        if rmq_message is None:
            raise ValueError("Message not set with _rmq_message.")
        logger.debug("ack message: {}".format(message))
        self.consumer.ack(rmq_message)

    def produce(self, message: Message) -> None:
        """Serialize a message to JSON and send it to the configured topic.

        Raises:
            Exception: If this backend was created without a producer.
        """
        if not self.has_producer:
            raise Exception("Producer not initialized.")
        # NOTE(review): model_config is normally shared by the whole Message class,
        # so this assignment has a process-wide side effect - confirm it is intended
        # and that to_json() actually honours it.
        message.model_config['use_enum_values'] = False
        json_str = message.to_json()
        rmq_message = rocketmq.Message()
        rmq_message.topic = self.topic
        rmq_message.body = json_str.encode('utf-8')
        # Ordered (FIFO) topics require a message group.
        rmq_message.message_group = self.message_group
        self.producer.send(rmq_message)

    def shutdown(self):
        """Shut down the started consumer and/or producer.

        Safe to call more than once; calls after the first are no-ops.
        """
        if getattr(self, '_is_shutdown', True):
            return
        self._is_shutdown = True
        try:
            if self.has_consumer:
                self.consumer.shutdown()
        finally:
            # Still shut the producer down if the consumer shutdown raised.
            if self.has_producer:
                self.producer.shutdown()

if __name__ == '__main__':
    logging_service.setup_root_logger()
    config = configs.get()
    # Smoke test against Aliyun RocketMQ using the receive topic/group from config.
    mq_conf = config['mq']

    def _consume_and_ack(backend):
        """Receive at most one message, print it, and ack it if one arrived.

        Returns the received message, or None when nothing was available.
        """
        received = backend.consume()
        print(received)
        if received:
            backend.ack(received)
        return received

    queue = AliyunRocketMQQueueBackend(
        mq_conf['endpoints'], mq_conf['instance_id'], mq_conf['receive_topic'],
        has_consumer=True, has_producer=True, group_id=mq_conf['receive_group'])
    try:
        # Drain messages already in the topic so the round trips below are more
        # likely to read back what they just produced.
        while _consume_and_ack(queue):
            pass

        # Round trip 1: an AGGREGATION_TRIGGER message with None as content.
        queue.produce(Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
                                    "user_id_1", "staff_id_0",
                                    None, int(time.time() * 1000)))
        _consume_and_ack(queue)

        # Round trip 2: a TEXT message.
        queue.produce(Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                    "user_id_1", "staff_id_0",
                                    "message_queue_backend test", int(time.time() * 1000)))
        _consume_and_ack(queue)
    finally:
        # Shut the clients down explicitly rather than relying on __del__ at exit.
        queue.shutdown()