""" generate category for given title """ import time import concurrent import traceback from concurrent.futures import ThreadPoolExecutor from pymysql.cursors import DictCursor from tqdm import tqdm from applications import log from applications.api.deep_seek_api_official import fetch_deepseek_completion from applications.const import CategoryGenerationTaskConst from applications.db import DatabaseConnector from applications.utils import yield_batch from config import long_articles_config from tasks.ai_tasks.prompts import category_generation_from_title class CategoryGenerationTask: def __init__(self): self.db_client = DatabaseConnector(long_articles_config) self.db_client.connect() self.const = CategoryGenerationTaskConst() def set_category_status_as_success( self, thread_db_client: DatabaseConnector, article_id: int, category: str ) -> int: update_query = f""" update publish_single_video_source set category = %s, category_status = %s, category_status_update_ts = %s where id = %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=( category, self.const.SUCCESS_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS, ), ) return update_rows def set_category_status_as_fail( self, thread_db_client: DatabaseConnector, article_id: int ) -> int: update_query = f""" update publish_single_video_source set category_status = %s, category_status_update_ts = %s where id = %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=( self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS, ), ) return update_rows def update_title_category( self, thread_db_client: DatabaseConnector, article_id: int, completion: dict ): try: category = completion.get(str(article_id)) self.set_category_status_as_success(thread_db_client, article_id, category) except Exception as e: log( task=self.const.TASK_NAME, function="update_each_record_status", message="AI返回格式失败,更新状态为失败", data={ "article_id": article_id, "error": str(e), "traceback": traceback.format_exc(), }, ) self.set_category_status_as_fail(thread_db_client, article_id) def rollback_lock_tasks(self) -> int: update_query = f""" update publish_single_video_source set category_status = %s where category_status = %s and category_status_update_ts <= %s; """ update_rows = self.db_client.save( query=update_query, params=( self.const.INIT_STATUS, self.const.PROCESSING_STATUS, int(time.time()) - self.const.MAX_PROCESSING_TIME, ), ) return update_rows def lock_task( self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...] ) -> int: update_query = f""" update publish_single_video_source set category_status = %s, category_status_update_ts = %s where id in %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=( self.const.PROCESSING_STATUS, int(time.time()), article_id_tuple, self.const.INIT_STATUS, ), ) return update_rows def deal_each_article(self, thread_db_client, article: dict): """ deal each article """ article_id = article["id"] title = article["article_title"] title_batch = [(article_id, title)] prompt = category_generation_from_title(title_batch) try: completion = fetch_deepseek_completion( model="DeepSeek-V3", prompt=prompt, output_type="json" ) self.update_title_category(thread_db_client, article_id, completion) except Exception as e: log( task=self.const.TASK_NAME, message="该文章存在敏感词,AI 拒绝返回", function="deal_each_article", data={ "article_id": article_id, "error": str(e), "traceback": traceback.format_exc(), }, ) self.set_category_status_as_fail(thread_db_client, article_id) def deal_batch_in_each_thread(self, task_batch: list[dict]): """ deal in each thread """ thread_db_client = DatabaseConnector(long_articles_config) thread_db_client.connect() title_batch = [(i["id"], i["article_title"]) for i in task_batch] id_tuple = tuple([int(i["id"]) for i in task_batch]) lock_rows = self.lock_task(thread_db_client, id_tuple) if lock_rows: prompt = category_generation_from_title(title_batch) try: completion = fetch_deepseek_completion( model="DeepSeek-V3", prompt=prompt, output_type="json" ) except Exception as e: log( task=self.const.TASK_NAME, function="category_generation_task", message=" batch 中存在敏感词,AI 拒绝返回", data={ "article_id": id_tuple, "error": str(e), "traceback": traceback.format_exc(), }, ) for article in tqdm(task_batch): self.deal_each_article(thread_db_client, article) return for article in title_batch: self.update_title_category(thread_db_client, article[0], completion) else: return def get_task_list(self): """ get task_list from a database """ fetch_query = f""" select id, article_title from publish_single_video_source where category_status = %s and bad_status = %s order by score desc; """ fetch_result = self.db_client.fetch( query=fetch_query, cursor_type=DictCursor, params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS), ) return fetch_result def deal(self): self.rollback_lock_tasks() task_list = self.get_task_list() task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE) ## dev # for task_batch in task_batch_list: # self.deal_batch_in_each_thread(task_batch) with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor: futures = [ executor.submit(self.deal_batch_in_each_thread, task_batch) for task_batch in task_batch_list ] for _ in tqdm( concurrent.futures.as_completed(futures), total=len(futures), desc="Processing batches", ): pass