"""
generate category for given title
"""

import time
import concurrent
import traceback
from concurrent.futures import ThreadPoolExecutor

from pymysql.cursors import DictCursor
from tqdm import tqdm

from applications import log
from applications.api.deep_seek_api_official import fetch_deepseek_completion
from applications.const import CategoryGenerationTaskConst
from applications.db import DatabaseConnector
from applications.utils import yield_batch

from config import long_articles_config
from tasks.ai_tasks.prompts import category_generation_from_title


class CategoryGenerationTask:
    """Generate a content category for pending video articles via DeepSeek.

    Rows in ``publish_single_video_source`` move through the state machine
    INIT -> PROCESSING -> SUCCESS / FAIL (status constants come from
    ``CategoryGenerationTaskConst``).  Batches of titles are classified
    concurrently; stale PROCESSING locks are rolled back on startup.
    """

    def __init__(self):
        # Connection used by the main thread only; worker threads open
        # their own connections (see deal_batch_in_each_thread).
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()
        self.const = CategoryGenerationTaskConst()

    def set_category_status_as_success(
        self, thread_db_client: "DatabaseConnector", article_id: int, category: str
    ) -> int:
        """Store *category* and mark the row SUCCESS.

        The WHERE clause keeps the update conditional on PROCESSING, so a
        worker whose lock was already rolled back cannot clobber newer state.
        Returns the number of rows updated (0 or 1).
        """
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def set_category_status_as_fail(
        self, thread_db_client: "DatabaseConnector", article_id: int
    ) -> int:
        """Mark the row FAIL (only while it is still PROCESSING).

        Returns the number of rows updated (0 or 1).
        """
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def update_title_category(
        self, thread_db_client: "DatabaseConnector", article_id: int, completion: dict
    ) -> None:
        """Apply the model *completion* (a ``{str(id): category}`` mapping) to one row.

        A missing or empty category is treated as a failure: previously
        ``completion.get`` silently returned ``None`` and the row ended up
        with a NULL category in SUCCESS status.
        """
        try:
            category = completion.get(str(article_id))
            if not category:
                # force the failure path instead of persisting NULL as success
                raise ValueError(f"no category returned for article {article_id}")
            self.set_category_status_as_success(thread_db_client, article_id, category)

        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="AI返回格式失败,更新状态为失败",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def rollback_lock_tasks(self) -> int:
        """Release PROCESSING locks older than MAX_PROCESSING_TIME back to INIT.

        Recovers rows abandoned by crashed or hung workers.
        Returns the number of rows reset.
        """
        update_query = """
            update publish_single_video_source
            set category_status = %s
            where category_status = %s and category_status_update_ts <= %s;
        """
        return self.db_client.save(
            query=update_query,
            params=(
                self.const.INIT_STATUS,
                self.const.PROCESSING_STATUS,
                int(time.time()) - self.const.MAX_PROCESSING_TIME,
            ),
        )

    def lock_task(
        self, thread_db_client: "DatabaseConnector", article_id_tuple: tuple[int, ...]
    ) -> int:
        """Atomically claim the given ids (INIT -> PROCESSING).

        Returns how many rows were actually locked; 0 means another
        worker claimed them all first.
        """
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id in %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )

    def deal_each_article(self, thread_db_client, article: dict) -> None:
        """Classify a single article; used as a fallback when a whole batch fails."""
        article_id = article["id"]
        title = article["article_title"]

        prompt = category_generation_from_title([(article_id, title)])
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)

        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                message="该文章存在敏感词,AI 拒绝返回",
                function="deal_each_article",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]) -> None:
        """Process one batch of articles inside a worker thread.

        Each thread opens its own connection because pymysql connections
        must not be shared across threads.  If the batched model call is
        rejected (e.g. one sensitive title), fall back to per-article
        requests so a single bad title does not poison the whole batch.
        """
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()

        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
        id_tuple = tuple(int(i["id"]) for i in task_batch)

        if not self.lock_task(thread_db_client, id_tuple):
            # every row in the batch was already claimed by another worker
            return

        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message=" batch 中存在敏感词,AI 拒绝返回",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return

        for article_id, _title in title_batch:
            self.update_title_category(thread_db_client, article_id, completion)

    def get_task_list(self) -> list[dict]:
        """Fetch INIT-status, non-bad articles from the database, best score first."""
        fetch_query = """
            select id, article_title from publish_single_video_source
            where category_status = %s and bad_status = %s
            order by score desc;
        """
        return self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
        )

    def deal(self) -> None:
        """Entry point: recover stale locks, then fan batches out to a thread pool."""
        self.rollback_lock_tasks()

        task_list = self.get_task_list()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)

        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]

            # drain futures as they complete so the progress bar advances
            for _ in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass