123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
"""
Generate a category for each article title with an LLM.
"""
- import concurrent
- from concurrent.futures import ThreadPoolExecutor
- from pymysql.cursors import DictCursor
- from tqdm import tqdm
- from applications.api.deep_seek_api_official import fetch_deepseek_completion
- from applications.const import CategoryGenerationTaskConst
- from applications.db import DatabaseConnector
- from applications.utils import yield_batch
- from config import long_articles_config
- from tasks.ai_tasks.prompts import category_generation_from_title
class CategoryGenerationTask:
    """
    Generate a category for each article title with an LLM.

    Rows of ``publish_single_video_source`` move through ``category_status``
    values taken from ``CategoryGenerationTaskConst``:
    INIT_STATUS -> PROCESSING_STATUS (claimed by a worker) -> SUCCESS_STATUS
    (category written). Rows that end up without a category are put back to
    INIT_STATUS so a later run can retry them.
    """

    def __init__(self):
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()
        self.const = CategoryGenerationTaskConst()

    def set_category(self, thread_db_client, article_id, category):
        """
        Store ``category`` for one article and mark it SUCCESS.

        The update only applies while the row is still PROCESSING, i.e. while it
        is held by ``lock_task``.

        :param thread_db_client: connected DatabaseConnector owned by the calling thread
        :param article_id: id of the row in publish_single_video_source
        :param category: category value to store
        :return: whatever ``thread_db_client.save`` returns (presumably the
            affected-row count)
        """
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s
            where id = %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS),
        )

    def lock_task(self, thread_db_client, article_id_tuple):
        """
        Claim articles by moving them from INIT to PROCESSING.

        :param thread_db_client: connected DatabaseConnector owned by the calling thread
        :param article_id_tuple: tuple of article ids to claim
        :return: result of ``save`` (presumably the number of rows claimed);
            0 for an empty tuple, which is never sent because ``id in ()`` is
            invalid SQL
        """
        if not article_id_tuple:
            return 0
        update_query = """
            update publish_single_video_source
            set category_status = %s
            where id in %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS),
        )

    def _release_lock(self, thread_db_client, article_id_tuple):
        """
        Put articles that are still PROCESSING back to INIT so they can be retried.

        Rows already moved to SUCCESS by ``set_category`` are unaffected.

        :param thread_db_client: connected DatabaseConnector owned by the calling thread
        :param article_id_tuple: tuple of article ids to release
        :return: result of ``save``; 0 for an empty tuple (nothing is sent)
        """
        if not article_id_tuple:
            return 0
        update_query = """
            update publish_single_video_source
            set category_status = %s
            where id in %s and category_status = %s;
        """
        return thread_db_client.save(
            query=update_query,
            params=(self.const.INIT_STATUS, article_id_tuple, self.const.PROCESSING_STATUS),
        )

    @staticmethod
    def _match_categories(title_batch, completion):
        """
        Pair each article id with the category the LLM returned for it.

        The completion is looked up by ``str(article_id)``, as the original code
        did. Any completion that is not a dict is treated as "no answer" for
        every article.

        :param title_batch: list of (article_id, article_title) pairs
        :param completion: parsed LLM output, expected to be a dict
        :return: (list of (article_id, category) with a non-empty category,
                  list of article ids that got no usable category)
        """
        if not isinstance(completion, dict):
            return [], [article_id for article_id, _ in title_batch]
        categorized, missing = [], []
        for article_id, _ in title_batch:
            category = completion.get(str(article_id))
            if category:
                categorized.append((article_id, category))
            else:
                missing.append(article_id)
        return categorized, missing

    @staticmethod
    def _close_client(client):
        """Close a per-thread DB client if it exposes ``close()``; failures are printed."""
        # NOTE(review): assumes DatabaseConnector offers close(); getattr makes this
        # a no-op if it does not (or if the client was never created)
        close = getattr(client, "close", None)
        if not callable(close):
            return
        try:
            close()
        except Exception as e:
            print(f"failed to close db client: {e!r}")

    def deal_in_each_thread(self, task_batch):
        """
        Categorize one batch of articles inside a worker thread.

        Steps:
        1. Lock the batch (INIT -> PROCESSING).
        2. Ask DeepSeek for a category per title.
        3. Store each returned category (PROCESSING -> SUCCESS).
        4. Put any article without a usable category back to INIT.

        If anything fails after the lock, the rows still PROCESSING are
        released back to INIT instead of being stranded. Errors are printed,
        not raised, so one failing batch does not stop the others.

        :param task_batch: rows from ``get_task_list``; each needs ``id`` and
            ``article_title`` keys
        """
        if not task_batch:
            return

        id_tuple = ()
        thread_db_client = None
        locked = False
        try:
            title_batch = [(row["id"], row["article_title"]) for row in task_batch]
            id_tuple = tuple(article_id for article_id, _ in title_batch)

            # each worker uses its own connection rather than sharing self.db_client
            thread_db_client = DatabaseConnector(long_articles_config)
            thread_db_client.connect()

            if not self.lock_task(thread_db_client, id_tuple):
                # nothing claimed, e.g. the rows are no longer in INIT status
                return
            locked = True
            # NOTE(review): if another process claimed part of this batch first, its
            # rows are PROCESSING too, and the updates below cannot tell them apart.

            prompt = category_generation_from_title(title_batch)
            completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json')
            categorized, missing_ids = self._match_categories(title_batch, completion)
            for article_id, category in categorized:
                self.set_category(thread_db_client, article_id, category)
            if missing_ids:
                # don't mark these SUCCESS with an empty category; let a later run retry
                self._release_lock(thread_db_client, tuple(missing_ids))
        except Exception as e:
            print(f"category generation failed for articles {id_tuple}: {e!r}")
            if locked:
                try:
                    self._release_lock(thread_db_client, id_tuple)
                except Exception as release_error:
                    print(f"failed to release articles {id_tuple}: {release_error!r}")
        finally:
            self._close_client(thread_db_client)

    def get_task_list(self):
        """
        Fetch all articles still waiting for a category, highest score first.

        :return: rows (dicts, via DictCursor) with ``id`` and ``article_title``
        """
        # INIT_STATUS is a trusted constant, so interpolating it keeps this query
        # consistent with lock_task without relying on fetch's parameter API
        fetch_query = f"""
            select id, article_title from publish_single_video_source
            where category_status = {self.const.INIT_STATUS} and bad_status = 0
            order by score desc;
        """
        return self.db_client.fetch(fetch_query, cursor_type=DictCursor)

    def deal(self, max_workers=5):
        """
        Categorize every pending article, one batch per thread-pool task.

        :param max_workers: number of worker threads; each opens its own DB connection
        """
        task_list = self.get_task_list()
        if not task_list:
            return
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # submit every batch to the pool
            futures = [
                executor.submit(self.deal_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            # tqdm only tracks completion; deal_in_each_thread reports its own errors
            for _ in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass
if __name__ == '__main__':
    # script entry point: build the task (connects to the DB) and process every pending batch
    CategoryGenerationTask().deal()
|