import abc
import pymysql.cursors

from typing import Dict, List, Optional

from pqai_agent.database import MySQLManager


class SessionManager(abc.ABC):
    """Abstract interface for querying staff sessions and conversations.

    Concrete subclasses supply the storage backend; every method here is
    abstract and has no behavior of its own.
    """

    @abc.abstractmethod
    def get_staff_sessions(
        self,
        staff_id: str,
        page_id: int = 1,
        page_size: int = 10,
        session_type: str = "default",
    ) -> List[Dict]:
        """Return the session records of ``staff_id`` for ``session_type``.

        :param staff_id: identifier of the staff member
        :param page_id: page number to fetch (defaults to 1)
        :param page_size: number of records per page (defaults to 10)
        :param session_type: which kind of sessions to select
        :return: list of session records
        """

    @abc.abstractmethod
    def get_staff_sessions_summary(
        self,
        staff_id: str,
        page_id: int,
        page_size: int,
        status: int,
    ) -> Dict:
        """Return a summary of sessions for staff matching ``status``.

        :param staff_id: identifier of the staff member
        :param page_id: page number to fetch
        :param page_size: number of records per page
        :param status: staff status to filter by
        :return: summary payload
        """

    @abc.abstractmethod
    def get_staff_session_list(
        self, staff_id: str, page_id: int, page_size: int
    ) -> Dict:
        """Return one page of the session list of ``staff_id``.

        :param staff_id: identifier of the staff member
        :param page_id: page number to fetch
        :param page_size: number of records per page
        :return: session list payload
        """

    @abc.abstractmethod
    def get_conversation_list(
        self, staff_id: str, user_id: str, page: Optional[int], page_size: int
    ) -> Dict:
        """Return one page of the conversation between a staff member and a user.

        :param staff_id: identifier of the staff member
        :param user_id: identifier of the user
        :param page: paging position, or ``None``
        :param page_size: number of records per page
        :return: conversation payload
        """


class MySQLSessionManager(SessionManager):

    def __init__(self, db_config, staff_table, user_table, agent_state_table, chat_history_table):
        self.db = MySQLManager(db_config)
        self.staff_table = staff_table
        self.user_table = user_table
        self.agent_state_table = agent_state_table
        self.chat_history_table = chat_history_table

    def get_staff_sessions(
        self,
        staff_id,
        page_id: int = 1,
        page_size: int = 10,
        session_type: str = "default",
    ) -> List[Dict]:
        """
        :param page_size:
        :param page_id:
        :param session_type:
        :param staff_id:
        :return:
        """
        match session_type:
            case "active":
                sql = f"""
                    select staff_id, current_state, user_id
                    from {self.agent_state_table}
                    where staff_id = %s and update_timestamp >= DATE_SUB(NOW(), INTERVAL 2 HOUR)
                    order by update_timestamp desc;
                """
            case "human_intervention":
                sql = f"""
                    select staff_id, current_state, user_id
                    from {self.agent_state_table}
                    where staff_id = %s and current_state = 5 order by update_timestamp desc;
                """
            case _:
                sql = f"""
                    select t1.staff_id, t1.current_state, t1.user_id, t2.name, t2.iconurl
                    from {self.agent_state_table} t1 join {self.user_table} t2 
                        on t1.user_id = t2.third_party_user_id
                    where t1.staff_id = %s
                    order by 
                        IF(t1.current_state = 5, 0, 1),
                        t1.update_timestamp desc
                    limit {page_size + 1} offset {page_size * (page_id - 1)};
                """

        staff_sessions = self.db.select(
            sql, cursor_type=pymysql.cursors.DictCursor, args=(staff_id,)
        )
        return staff_sessions

    def get_staff_sessions_summary(
        self, staff_id, page_id: int, page_size: int, status: int
    ) -> Dict:
        """
        :param status: staff status(0: unemployed, 1: employed)
        :param staff_id: staff
        :param page_id: page id
        :param page_size: page size
        :return:
        :todo: 未使用 Mysql 连接池,每次查询均需要与 MySQL 建立连接,性能较低,需要优化
        """
        if not staff_id:
            get_staff_query = f"""
                select third_party_user_id, name from {self.staff_table} where status = %s
                limit %s offset %s;
            """
            staff_id_list = self.db.select(
                sql=get_staff_query,
                cursor_type=pymysql.cursors.DictCursor,
                args=(status, page_size + 1, (page_id - 1) * page_size),
            )
            if not staff_id_list:
                return {}

            if len(staff_id_list) > page_size:
                has_next_page = True
                next_page_id = page_id + 1
                staff_id_list = staff_id_list[:page_size]
            else:
                has_next_page = False
                next_page_id = None
        else:
            get_staff_query = f"""
                select third_party_user_id, name from {self.staff_table} 
                where status = %s and third_party_user_id = %s;
            """
            staff_id_list = self.db.select(
                sql=get_staff_query,
                cursor_type=pymysql.cursors.DictCursor,
                args=(status, staff_id),
            )
            if not staff_id_list:
                return {}
            has_next_page = False
            next_page_id = None
        response_data = [
            {
                "staff_id": staff["third_party_user_id"],
                "staff_name": staff["name"],
                "active_sessions": len(
                    self.get_staff_sessions(
                        staff["third_party_user_id"], session_type="active"
                    )
                ),
                "human_intervention_sessions": len(
                    self.get_staff_sessions(
                        staff["third_party_user_id"], session_type="human_intervention"
                    )
                ),
            }
            for staff in staff_id_list
        ]
        return {
            "has_next_page": has_next_page,
            "next_page_id": next_page_id,
            "data": response_data,
        }

    def get_staff_session_list(self, staff_id, page_id: int, page_size: int) -> Dict:
        """
        :param page_size:
        :param page_id:
        :param staff_id:
        :return:
        """
        session_list = self.get_staff_sessions(staff_id, page_id, page_size)
        if len(session_list) > page_size:
            has_next_page = True
            next_page_id = page_id + 1
            session_list = session_list[:page_size]
        else:
            has_next_page = False
            next_page_id = None
        response_data = []
        for session in session_list:
            temp_obj = {}
            user_id = session["user_id"]
            room_id = ":".join(["private", staff_id, user_id])
            select_query = f"""
                select content, sendtime as max_timestamp, msg_type
                from {self.chat_history_table} 
                where roomid = %s
                order by sendtime desc limit %s;
            """
            last_message = self.db.select(
                sql=select_query,
                cursor_type=pymysql.cursors.DictCursor,
                args=(room_id, 1),
            )
            if not last_message:
                temp_obj["message"] = None
                temp_obj["timestamp"] = 0
                temp_obj["msg_type"] = None
            else:
                temp_obj["message"] = last_message[0]["content"]
                temp_obj["timestamp"] = last_message[0]["max_timestamp"]
                temp_obj["msg_type"] = last_message[0]["msg_type"]
            temp_obj["user_id"] = user_id
            temp_obj["user_name"] = session["name"]
            temp_obj["avatar"] = session["iconurl"]
            temp_obj["current_state"] = session["current_state"]
            response_data.append(temp_obj)
        return {
            "staff_id": staff_id,
            "has_next_page": has_next_page,
            "next_page_id": next_page_id,
            "data": response_data,
        }

    def get_conversation_list(
        self, staff_id: str, user_id: str, page: Optional[int], page_size: int
    ):
        """
        :param page_size:
        :param staff_id:
        :param user_id:
        :param page: timestamp
        :return:
        """
        room_id = ":".join(["private", staff_id, user_id])
        if not page:
            fetch_query = f"""
                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
                from {self.chat_history_table} t1
                join {self.user_table} t2 on t1.sender = t2.third_party_user_id
                where roomid = %s
                order by sendtime desc
                limit %s;
            """
            messages = self.db.select(
                sql=fetch_query,
                cursor_type=pymysql.cursors.DictCursor,
                args=(room_id, page_size + 1),
            )
        else:
            fetch_query = f"""
                select t1.sender, t2.name, t1.sendtime, t1.content, t2.iconurl, t1.msg_type
                from {self.chat_history_table} t1
                join {self.user_table} t2 on t1.sender = t2.third_party_user_id
                where t1.roomid = %s and t1.sendtime <= %s
                order by sendtime desc
                limit %s;
            """
            messages = self.db.select(
                sql=fetch_query,
                cursor_type=pymysql.cursors.DictCursor,
                args=(room_id, page, page_size + 1),
            )
        if messages:
            if len(messages) > page_size:
                has_next_page = True
                next_page = messages[-1]["sendtime"]
            else:
                has_next_page = False
                next_page = None
            response_data = [
                {
                    "sender_id": message["sender"],
                    "sender_name": message["name"],
                    "avatar": message["iconurl"],
                    "content": message["content"],
                    "timestamp": message["sendtime"],
                    "msg_type": message["msg_type"],
                    "role": "user" if message["sender"] == user_id else "staff",
                }
                for message in messages
            ]
            return {
                "staff_id": staff_id,
                "user_id": user_id,
                "has_next_page": has_next_page,
                "next_page": next_page,
                "data": response_data,
            }
        else:
            has_next_page = False
            next_page = None
            return {
                "staff_id": staff_id,
                "user_id": user_id,
                "has_next_page": has_next_page,
                "next_page": next_page,
                "data": [],
            }