"""Generate a category for each title using the DeepSeek model.

Two concrete tasks are provided:

* ``VideoPoolCategoryGenerationTask`` works on ``publish_single_video_source``.
* ``ArticlePoolCategoryGenerationTask`` works on the table named by
  ``CategoryGenerationTaskConst.ARTICLE_TABLE_NAME``.

Rows move through ``category_status`` INIT -> PROCESSING -> SUCCESS / FAIL.
A batch is "locked" by flipping its rows from INIT to PROCESSING; locks that
have been held for at least ``MAX_PROCESSING_TIME`` seconds are reset to INIT
at the start of every run.
"""

import time
import concurrent.futures
import traceback
from concurrent.futures import ThreadPoolExecutor

from pymysql.cursors import DictCursor
from tqdm import tqdm

from applications import log
from applications.api.deep_seek_api_official import fetch_deepseek_completion
from applications.const import CategoryGenerationTaskConst
from applications.db import DatabaseConnector
from applications.utils import yield_batch
from config import long_articles_config
from tasks.ai_tasks.prompts import category_generation_from_title


class CategoryGenerationTask:
    """Shared plumbing for the category generation tasks.

    Subclasses must provide ``set_category_status_as_success``,
    ``set_category_status_as_fail`` and ``deal_batch_in_each_thread`` for
    their own table; the shared methods below call them.
    """

    def __init__(self):
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()
        self.const = CategoryGenerationTaskConst()

    def rollback_lock_tasks(self, table_name) -> int:
        """Roll back tasks whose PROCESSING lock has expired.

        Rows of ``table_name`` that have been in PROCESSING status for at least
        ``MAX_PROCESSING_TIME`` seconds are reset to INIT so a later run can
        pick them up again.

        :param table_name: table to operate on. It is interpolated into the
            SQL text, so it must be a trusted constant, never user input.
        :return: the value returned by ``db_client.save`` (presumably the
            number of affected rows).
        """
        update_query = f"""
            update {table_name}
            set category_status = %s
            where category_status = %s and category_status_update_ts <= %s;
        """
        return self.db_client.save(
            query=update_query,
            params=(
                self.const.INIT_STATUS,
                self.const.PROCESSING_STATUS,
                int(time.time()) - self.const.MAX_PROCESSING_TIME,
            ),
        )

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        """Persist the category the model produced for ``article_id``.

        ``completion`` is expected to map ``str(article_id)`` to a category.
        The row is marked as failed instead of successful if any of these hold:

        * ``completion`` is not dict-like.
        * It has no entry, or an empty entry, for this id.
        * The success update raises.
        """
        try:
            category = completion.get(str(article_id))
            if not category:
                # Without this check a missing key would be stored as a NULL
                # category with SUCCESS status, and the row would never be retried.
                raise ValueError(
                    f"no category in completion for article_id={article_id}"
                )
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_each_record_status",
                message="AI返回格式失败,更新状态为失败",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def _process_batches_concurrently(self, task_list) -> None:
        """Split ``task_list`` into batches and handle each in a thread pool.

        Each batch is passed to the subclass's ``deal_batch_in_each_thread``.
        An exception escaping a worker is logged rather than silently dropped,
        and the remaining batches keep running.
        """
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                try:
                    # ``result()`` re-raises whatever the worker raised; without
                    # this call such errors would vanish unnoticed.
                    future.result()
                except Exception as e:
                    log(
                        task=self.const.TASK_NAME,
                        function="deal",
                        message="batch worker raised an unhandled exception",
                        data={
                            "error": str(e),
                            "traceback": traceback.format_exc(),
                        },
                    )


class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
    """Generate categories for rows of ``publish_single_video_source``."""

    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        """Store ``category`` and move the row from PROCESSING to SUCCESS."""
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        """Move the row from PROCESSING to FAIL."""
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        """Move the given rows from INIT to PROCESSING.

        :return: the value returned by ``save``. Zero means no row was still
            in INIT status.
        """
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id in %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )

    def deal_each_article(self, thread_db_client, article: dict):
        """Categorise a single (already locked) row with its own model request."""
        article_id = article["id"]
        title = article["article_title"]
        prompt = category_generation_from_title([(article_id, title)])
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                message="该文章存在敏感词,AI 拒绝返回",
                function="deal_each_article",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """Lock a batch of rows and categorise them with a single model request.

        If the batch request fails, each row is retried individually so that a
        single rejected title does not fail the whole batch.
        """
        if not task_batch:
            # ``id in ()`` is not valid SQL; nothing to do anyway.
            return

        # NOTE(review): a new connection is opened per batch and never closed
        # here. Confirm whether DatabaseConnector needs an explicit close.
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()

        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
        id_tuple = tuple(int(i["id"]) for i in task_batch)
        if not self.lock_task(thread_db_client, id_tuple):
            # No row of this batch was still in INIT status.
            return

        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="category_generation_task",
                message=" batch 中存在敏感词,AI 拒绝返回",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return

        for article_id, _title in title_batch:
            self.update_title_category(thread_db_client, article_id, completion)

    def get_task_list(self):
        """Fetch all INIT rows with ARTICLE_GOOD_STATUS, highest score first."""
        fetch_query = """
            select id, article_title from publish_single_video_source
            where category_status = %s and bad_status = %s
            order by score desc;
        """
        return self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
        )

    def deal(self):
        """Roll back expired locks, then categorise every pending row."""
        # NOTE(review): rollback uses const.VIDEO_TABLE_NAME while the other
        # queries hard-code publish_single_video_source. Confirm they match.
        self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)
        task_list = self.get_task_list()
        self._process_batches_concurrently(task_list)


class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
    """Generate categories (``category_by_ai``) for ``ARTICLE_TABLE_NAME`` rows."""

    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        """Store ``category`` and move the row from PROCESSING to SUCCESS."""
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        """Move the row from PROCESSING to FAIL."""
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        """Move the given rows from INIT to PROCESSING.

        :return: the value returned by ``save``. Zero means no row was still
            in INIT status.
        """
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id in %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )

    def deal_each_article(self, thread_db_client, article: dict):
        """Categorise a single (already locked) row with its own model request."""
        article_id = article["article_id"]
        title = article["title"]
        prompt = category_generation_from_title([(article_id, title)])
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                message="该文章存在敏感词,AI 拒绝返回",
                function="deal_each_article",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """Lock a batch of rows and categorise them with a single model request.

        If the batch request fails, each row is retried individually so that a
        single rejected title does not fail the whole batch.
        """
        if not task_batch:
            # ``article_id in ()`` is not valid SQL; nothing to do anyway.
            return

        # NOTE(review): a new connection is opened per batch and never closed
        # here. Confirm whether DatabaseConnector needs an explicit close.
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()

        title_batch = [(i["article_id"], i["title"]) for i in task_batch]
        id_tuple = tuple(int(i["article_id"]) for i in task_batch)
        if not self.lock_task(thread_db_client, id_tuple):
            # No row of this batch was still in INIT status.
            return

        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="article_category_generation_task",
                message="batch 中存在敏感词,AI 拒绝返回",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return

        for article_id, _title in title_batch:
            self.update_title_category(thread_db_client, article_id, completion)

    def get_task_list(self):
        """Fetch up to 1000 pending rows above LIMIT_SCORE, highest score first."""
        fetch_query = f"""
            select article_id, title from {self.const.ARTICLE_TABLE_NAME}
            where category_status = %s and status = %s and score > %s
            order by score desc limit 1000;
        """
        return self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(
                self.const.INIT_STATUS,
                self.const.ARTICLE_INIT_STATUS,
                self.const.LIMIT_SCORE,
            ),
        )

    def deal(self):
        """Roll back expired locks, then categorise the fetched pending rows."""
        self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
        task_list = self.get_task_list()
        self._process_batches_concurrently(task_list)