""" generate category for given title """ import concurrent from concurrent.futures import ThreadPoolExecutor from pymysql.cursors import DictCursor from tqdm import tqdm from applications.api.deep_seek_api_official import fetch_deepseek_completion from applications.const import CategoryGenerationTaskConst from applications.db import DatabaseConnector from applications.utils import yield_batch from config import long_articles_config from tasks.ai_tasks.prompts import category_generation_from_title class CategoryGenerationTask: """ generate category for given title """ def __init__(self): self.db_client = DatabaseConnector(long_articles_config) self.db_client.connect() self.const = CategoryGenerationTaskConst() def set_category(self, thread_db_client, article_id, category): """ set category for given article """ update_query = f""" update publish_single_video_source set category = %s, category_status = %s where id = %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS) ) return update_rows def lock_task(self, thread_db_client, article_id_tuple): """ lock_task """ update_query = f""" update publish_single_video_source set category_status = %s where id in %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS) ) return update_rows def deal_in_each_thread(self, task_batch): try: thread_db_client = DatabaseConnector(long_articles_config) thread_db_client.connect() title_batch = [(i['id'], i['article_title']) for i in task_batch] id_tuple = tuple([i['id'] for i in task_batch]) # lock task lock_rows = self.lock_task(thread_db_client, id_tuple) if lock_rows: prompt = category_generation_from_title(title_batch) # print(prompt) completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json') for article in title_batch: article_id = str(article[0]) category = completion.get(article_id) self.set_category(thread_db_client, article_id, category) else: return except Exception as e: print(e) def get_task_list(self): """ get task_list from a database """ fetch_query = f""" select id, article_title from publish_single_video_source where category_status = 0 and bad_status = 0 order by score desc; """ fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor) return fetch_result def deal(self): task_list = self.get_task_list() task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE) max_workers = 5 with ThreadPoolExecutor(max_workers=max_workers) as executor: # 提交所有任务到线程池 futures = [executor.submit(self.deal_in_each_thread, task_batch) for task_batch in task_batch_list] # 用 tqdm 跟踪任务完成进度 for _ in tqdm( concurrent.futures.as_completed(futures), total=len(futures), desc="Processing batches" ): pass # 仅用于更新进度条,不需要结果 if __name__ == '__main__': category_generation_task = CategoryGenerationTask() category_generation_task.deal()