#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

from logging_service import logger
from typing import Dict, Optional, Tuple, Any, List
import json
import time
import os
import abc

import pymysql.cursors

import configs
from database import MySQLManager


class UserManager(abc.ABC):
    """Abstract interface for loading and persisting user and staff profiles."""

    @abc.abstractmethod
    def get_user_profile(self, user_id) -> Dict:
        """Return the profile dict of ``user_id``."""

    @abc.abstractmethod
    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Persist ``profile`` as the profile of ``user_id``."""

    @abc.abstractmethod
    def list_all_users(self):
        """Return the ids of all known users."""

    @abc.abstractmethod
    def get_staff_profile(self, staff_id) -> Dict:
        """Return the profile dict of the staff member ``staff_id``."""
        # FIXME(zhoutian): redesign the user / staff data management model

    @staticmethod
    def get_default_profile(**kwargs) -> Dict:
        """Build a fresh default user profile.

        Keyword arguments whose name matches a default profile key replace that
        key's value; any other keyword argument is silently ignored. A new dict
        (with new nested containers) is built on every call, and ``created_at``
        is the current time in milliseconds since the epoch.
        """
        profile = {
            "name": "",
            "nickname": "",
            "preferred_nickname": "",
            "gender": "未知",
            "age": 0,
            "region": '',
            "interests": [],
            "family_members": {},
            "health_conditions": [],
            "medications": [],
            "reminder_preferences": {
                "medication": True,
                "health": True,
                "weather": True,
                "news": False
            },
            "interaction_style": "standard",  # one of: standard, verbose, concise
            "interaction_frequency": "medium",  # one of: low, medium, high
            "last_topics": [],
            "created_at": int(time.time() * 1000),
            "human_intervention_history": []
        }
        # Updating existing keys keeps their original insertion position.
        overrides = {key: value for key, value in kwargs.items() if key in profile}
        profile.update(overrides)
        return profile

    def list_users(self, **kwargs) -> List[Dict]:
        """Look users up by filter criteria; this base implementation returns None."""
        return None

class UserRelationManager(abc.ABC):
    """Abstract interface for querying staff members and the users related to them."""

    @abc.abstractmethod
    def list_staffs(self):
        """Return the staff members."""

    @abc.abstractmethod
    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """Return one page (``page``, ``page_size``) of the users related to ``staff_id``."""

    @abc.abstractmethod
    def list_staff_users(self) -> List[Dict]:
        """Return staff/user relation records."""

    @abc.abstractmethod
    def get_user_tags(self, user_id: str) -> List[str]:
        """Return the tag names attached to ``user_id``."""

class LocalUserManager(UserManager):
    """File-backed user manager storing one JSON file per user; mainly for local debugging.

    Profiles live in ``PROFILE_DIR/<user_id>.json``, relative to the current
    working directory.
    """

    # Directory (relative to the CWD) holding the per-user JSON profile files.
    PROFILE_DIR = "user_profiles"

    def _profile_path(self, user_id) -> str:
        """Return the JSON file path used for ``user_id``."""
        return os.path.join(self.PROFILE_DIR, f"{user_id}.json")

    def get_user_profile(self, user_id) -> Dict:
        """Load a user's profile; if none exists, create, save and return a default one.

        Mainly used for local debugging.
        """
        try:
            with open(self._profile_path(user_id), "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            # No profile yet: create the default user profile.
            default_profile = self.get_default_profile()
            self.save_user_profile(user_id, default_profile)
            return default_profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Write ``profile`` to the user's JSON file (UTF-8, pretty-printed).

        Raises:
            Exception: if ``user_id`` is falsy.
        """
        if not user_id:
            raise Exception("Invalid user_id: {}".format(user_id))
        # Create the directory on first use; without it every save (including the
        # default-profile creation in get_user_profile) failed with FileNotFoundError.
        os.makedirs(self.PROFILE_DIR, exist_ok=True)
        with open(self._profile_path(user_id), "w", encoding="utf-8") as f:
            json.dump(profile, f, ensure_ascii=False, indent=2)

    def list_all_users(self):
        """Return the ids of all users that have a profile file.

        Only ``*.json`` files directly inside ``PROFILE_DIR`` are considered,
        because that is the only place get_user_profile looks. A missing
        directory yields an empty list.
        """
        if not os.path.isdir(self.PROFILE_DIR):
            return []
        return [
            os.path.splitext(name)[0]
            for name in os.listdir(self.PROFILE_DIR)
            if name.endswith('.json') and os.path.isfile(os.path.join(self.PROFILE_DIR, name))
        ]

    def get_staff_profile(self, staff_id) -> Dict:
        """Staff profiles are not stored locally; always returns an empty dict."""
        return {}

    def list_users(self, **kwargs) -> List[Dict]:
        """User search is not supported locally; always returns an empty list."""
        return []

class MySQLUserManager(UserManager):
    """User manager backed by MySQL.

    User profiles are stored as JSON in the ``profile_data_v1`` column of
    ``table_name``. Staff profiles are read from ``staff_table``.
    """

    # Integer gender codes read from the DB -> display labels. Unknown codes fall back to '未知'.
    _GENDER_MAP = {0: '未知', 1: '男', 2: '女', None: '未知'}

    def __init__(self, db_config, table_name, staff_table):
        self.db = MySQLManager(db_config)
        self.table_name = table_name
        self.staff_table = staff_table

    @staticmethod
    def _quote(value) -> str:
        """Render ``str(value)`` as an escaped, single-quoted SQL string literal.

        Used because ``self.db.select`` is only ever called here without bind
        parameters; escaping prevents injection from caller-supplied ids and names.
        """
        import pymysql.converters
        return pymysql.converters.escape_item(str(value), 'utf8mb4')

    @classmethod
    def _gender_label(cls, code) -> str:
        """Map a DB gender code to its label, defaulting to '未知' for unknown codes."""
        return cls._GENDER_MAP.get(code, '未知')

    def get_user_profile(self, user_id) -> Dict:
        """Load ``user_id``'s profile.

        Behavior:
        - Returns ``{}`` (after logging an error) if the user row does not exist.
        - If the stored profile is empty, a default profile is saved and returned.
          It is seeded with the row's ``name`` as nickname and its gender label.
        - Otherwise, keys missing from the stored profile are back-filled from
          that default, and the profile is saved if anything was added.
        """
        # The id is quoted (it is presumably a string column): an unquoted literal made
        # MySQL compare numerically and allowed SQL injection through user_id.
        sql = (f"SELECT name, wxid, profile_data_v1, gender FROM {self.table_name}"
               f" WHERE third_party_user_id = {self._quote(user_id)}")
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        if not data:
            logger.error(f"user[{user_id}] not found")
            return {}
        row = data[0]
        default_profile = self.get_default_profile(nickname=row['name'],
                                                   gender=self._gender_label(row['gender']))
        if not row['profile_data_v1']:
            logger.warning(f"user[{user_id}] profile not found, create a default one")
            self.save_user_profile(user_id, default_profile)
            return default_profile
        profile = json.loads(row['profile_data_v1'])
        # When new entries are added to the default profile, merge them into stored profiles.
        entry_added = False
        for key, value in default_profile.items():
            if key not in profile:
                logger.debug(f"user[{user_id}] add profile key[{key}] value[{value}]")
                profile[key] = value
                entry_added = True
        if entry_added:
            self.save_user_profile(user_id, profile)
        return profile

    def save_user_profile(self, user_id, profile: Dict) -> None:
        """Store ``profile`` as JSON in the user's ``profile_data_v1`` column.

        Raises:
            Exception: if ``user_id`` is falsy.
        """
        if not user_id:
            raise Exception("Invalid user_id: {}".format(user_id))
        # Both values are bound parameters; the id is no longer interpolated into the SQL.
        sql = f"UPDATE {self.table_name} SET profile_data_v1 = %s WHERE third_party_user_id = %s"
        self.db.execute(sql, (json.dumps(profile), str(user_id)))

    def list_all_users(self):
        """Return the ``third_party_user_id`` of every row in the user table."""
        sql = f"SELECT third_party_user_id FROM {self.table_name}"
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        return [user['third_party_user_id'] for user in data or []]

    def get_staff_profile(self, staff_id) -> Dict:
        """Return the agent_* columns of ``staff_id`` with ``agent_gender`` mapped to a label.

        Returns ``{}`` (after logging an error) if the staff row does not exist.

        Raises:
            Exception: if no staff table was configured.
        """
        if not self.staff_table:
            raise Exception("staff_table is not set")
        sql = (f"SELECT agent_name, agent_gender, agent_age, agent_region "
               f"FROM {self.staff_table} WHERE third_party_user_id = {self._quote(staff_id)}")
        data = self.db.select(sql, pymysql.cursors.DictCursor)
        if not data:
            logger.error(f"staff[{staff_id}] not found")
            return {}
        profile = data[0]
        # Convert the gender code to its label.
        profile['agent_gender'] = self._gender_label(profile['agent_gender'])
        return profile

    def list_users(self, **kwargs) -> List[Dict]:
        """Find users by ``user_union_id`` and/or ``user_name``.

        ``user_union_id`` is matched against ``wxid``. ``user_name`` is an exact,
        binary-collation match against ``name``. When both are given, both must match.

        Raises:
            Exception: if neither ``user_union_id`` nor ``user_name`` is given.
        """
        user_union_id = kwargs.get('user_union_id', None)
        user_name = kwargs.get('user_name', None)
        if not user_union_id and not user_name:
            raise Exception("user_union_id or user_name is required")
        sql = f"SELECT third_party_user_id, wxid, name, iconurl, gender FROM {self.table_name} WHERE 1=1 "
        # Caller-supplied values are escaped to prevent SQL injection.
        if user_name:
            sql += f"AND name = {self._quote(user_name)} COLLATE utf8mb4_bin "
        if user_union_id:
            sql += f"AND wxid = {self._quote(user_union_id)} "
        return self.db.select(sql, pymysql.cursors.DictCursor)


class MySQLUserRelationManager(UserRelationManager):
    """Resolves staff/user relations by reading both the agent DB and the WeCom DB."""

    # FIXME(zhoutian): test-period logic -- only these agent staff accounts are processed.
    _TEST_STAFF_IDS = ('1688854492669990', '1688855931724582')
    # FIXME(zhoutian): test-period logic -- this staff account gets all of its users;
    # every other staff account only gets users carrying tag_id 15.
    _UNFILTERED_STAFF_ID = '1688854492669990'
    # Maximum number of union ids per ``IN (...)`` lookup against the agent user table.
    _UNION_ID_BATCH_SIZE = 100

    def __init__(self, agent_db_config, wecom_db_config,
                 agent_staff_table, agent_user_table,
                 staff_table, relation_table, user_table):
        # FIXME(zhoutian): the database tables are not unified yet, so data is read from two databases.
        self.agent_db = MySQLManager(agent_db_config)
        self.wecom_db = MySQLManager(wecom_db_config)
        self.agent_staff_table = agent_staff_table
        self.staff_table = staff_table
        self.relation_table = relation_table
        self.agent_user_table = agent_user_table
        self.user_table = user_table

    @staticmethod
    def _sql_literal(value) -> str:
        """Render ``value`` as an escaped SQL literal (strings quoted, numbers bare)."""
        import pymysql.converters
        return pymysql.converters.escape_item(value, 'utf8mb4')

    @classmethod
    def _sql_in_list(cls, values) -> str:
        """Render a non-empty sequence as an escaped ``(v1, v2, ...)`` list for ``IN``.

        Replaces ``str(tuple)``, which renders one element as ``(x,)``. That is a
        MySQL syntax error, and str(tuple) also does no SQL escaping.
        """
        return "(" + ", ".join(cls._sql_literal(value) for value in values) + ")"

    def list_staffs(self):
        """Return all agent staff rows with ``status = 1``."""
        sql = f"SELECT third_party_user_id, name, wxid, agent_name FROM {self.agent_staff_table} WHERE status = 1"
        data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
        return data

    def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
        """Not implemented yet; always returns an empty list."""
        return []

    def list_staff_users(self) -> List[Dict]:
        """Return ``{'staff_id', 'user_id'}`` pairs (agent-DB ids) for the active test staff accounts.

        Staff accounts whose users cannot be resolved in either database are
        logged and skipped.
        """
        # FIXME(zhoutian): test-period logic, only the accounts in _TEST_STAFF_IDS are used.
        sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_staff_table} WHERE status = 1"
               f" AND third_party_user_id in {self._sql_in_list(self._TEST_STAFF_IDS)}")
        agent_staff_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
        if not agent_staff_data:
            return []
        ret = []
        for agent_staff in agent_staff_data:
            agent_staff_id = agent_staff['third_party_user_id']
            agent_users = self._list_agent_users_of_staff(agent_staff_id, agent_staff['wxid'])
            ret.extend(
                {'staff_id': agent_staff_id, 'user_id': agent_user['third_party_user_id']}
                for agent_user in agent_users
            )
        return ret

    def _list_agent_users_of_staff(self, agent_staff_id, wxid) -> List[Dict]:
        """Return agent-DB user rows related to one staff account, or [] after logging a miss.

        The chain is: staff wxid -> WeCom staff id -> related WeCom user ids
        -> their union ids -> agent-DB users whose wxid is one of those union ids.
        """
        sql = f"SELECT id FROM {self.staff_table} WHERE carrier_id = {self._sql_literal(str(wxid))}"
        staff_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not staff_data:
            logger.error(f"staff[{wxid}] not found in wecom database")
            return []
        staff_id = staff_data[0]['id']
        sql = (f"SELECT user_id FROM {self.relation_table}"
               f" WHERE staff_id = {self._sql_literal(str(staff_id))} AND is_delete = 0")
        user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not user_data:
            logger.warning(f"staff[{wxid}] has no user")
            return []
        user_ids = [user['user_id'] for user in user_data]
        sql = (f"SELECT union_id FROM {self.user_table} WHERE id IN {self._sql_in_list(user_ids)}"
               f" AND union_id is not null")
        # FIXME(zhoutian): test-period logic, see _UNFILTERED_STAFF_ID.
        if agent_staff_id != self._UNFILTERED_STAFF_ID:
            sql += (" AND id in (SELECT distinct user_id FROM we_com_user_with_tag"
                    " WHERE tag_id = 15 and is_delete = 0)")
        user_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        if not user_data:
            logger.warning(f"staff[{wxid}] users not found in wecom database")
            return []
        union_ids = [user['union_id'] for user in user_data]
        return self._select_agent_users_by_union_ids(union_ids)

    def _select_agent_users_by_union_ids(self, union_ids) -> List[Dict]:
        """Fetch agent-DB users whose wxid is in ``union_ids``, querying in fixed-size batches."""
        rows = []
        for start in range(0, len(union_ids), self._UNION_ID_BATCH_SIZE):
            batch = union_ids[start:start + self._UNION_ID_BATCH_SIZE]
            sql = (f"SELECT third_party_user_id, wxid FROM {self.agent_user_table}"
                   f" WHERE wxid IN {self._sql_in_list(batch)}")
            rows.extend(self.agent_db.select(sql, pymysql.cursors.DictCursor) or [])
        return rows

    def get_user_tags(self, user_id: str) -> List[str]:
        """Return the WeCom tag names of ``user_id`` ([] if the user has no wxid)."""
        sql = (f"SELECT wxid FROM {self.agent_user_table}"
               f" WHERE third_party_user_id = {self._sql_literal(str(user_id))} AND wxid is not null")
        user_data = self.agent_db.select(sql, pymysql.cursors.DictCursor)
        if not user_data:
            logger.error(f"user[{user_id}] has no wxid")
            return []
        user_wxid = user_data[0]['wxid']
        # NOTE(review): unlike list_staff_users, this does not filter b.is_delete = 0,
        # so deleted tag links may be included -- confirm whether that is intended.
        sql = f"""
            select b.tag_id, c.`tag_name`  from `we_com_user` as a
              join `we_com_user_with_tag` as b
              join `we_com_tag` as c
              on a.`id` = b.`user_id`
              and b.`tag_id` = c.id
              where a.union_id = {self._sql_literal(str(user_wxid))} """
        tag_data = self.wecom_db.select(sql, pymysql.cursors.DictCursor)
        return [tag['tag_name'] for tag in tag_data or []]


if __name__ == '__main__':
    # Ad-hoc smoke test against the configured databases: print one user
    # profile, then one user's tags.
    config = configs.get()
    storage = config['storage']
    user_db_config = storage['user']
    staff_db_config = storage['staff']

    user_manager = MySQLUserManager(user_db_config['mysql'],
                                    user_db_config['table'],
                                    staff_db_config['table'])
    user_profile = user_manager.get_user_profile('7881301263964433')
    print(user_profile)

    wecom_db_config = storage['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        agent_db_config=user_db_config['mysql'],
        wecom_db_config=wecom_db_config['mysql'],
        agent_staff_table=staff_db_config['table'],
        agent_user_table=user_db_config['table'],
        staff_table=wecom_db_config['table']['staff'],
        relation_table=wecom_db_config['table']['relation'],
        user_table=wecom_db_config['table']['user'],
    )
    # To also dump every staff/user pair, call user_relation_manager.list_staff_users().
    user_tags = user_relation_manager.get_user_tags('7881302078008656')
    print(user_tags)