# category_generation.py
  1. """
  2. generate category for given title
  3. """
  4. import concurrent
  5. from concurrent.futures import ThreadPoolExecutor
  6. from pymysql.cursors import DictCursor
  7. from tqdm import tqdm
  8. from applications.api.deep_seek_api_official import fetch_deepseek_completion
  9. from applications.const import CategoryGenerationTaskConst
  10. from applications.db import DatabaseConnector
  11. from applications.utils import yield_batch
  12. from config import long_articles_config
  13. from tasks.ai_tasks.prompts import category_generation_from_title
  14. class CategoryGenerationTask:
  15. """
  16. generate category for given title
  17. """
  18. def __init__(self):
  19. self.db_client = DatabaseConnector(long_articles_config)
  20. self.db_client.connect()
  21. self.const = CategoryGenerationTaskConst()
  22. def set_category(self, thread_db_client, article_id, category):
  23. """
  24. set category for given article
  25. """
  26. update_query = f"""
  27. update publish_single_video_source
  28. set category = %s, category_status = %s
  29. where id = %s and category_status = %s;
  30. """
  31. update_rows = thread_db_client.save(
  32. query=update_query,
  33. params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS)
  34. )
  35. return update_rows
  36. def lock_task(self, thread_db_client, article_id_tuple):
  37. """
  38. lock_task
  39. """
  40. update_query = f"""
  41. update publish_single_video_source
  42. set category_status = %s
  43. where id in %s and category_status = %s;
  44. """
  45. update_rows = thread_db_client.save(
  46. query=update_query,
  47. params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS)
  48. )
  49. return update_rows
  50. def deal_in_each_thread(self, task_batch):
  51. try:
  52. thread_db_client = DatabaseConnector(long_articles_config)
  53. thread_db_client.connect()
  54. title_batch = [(i['id'], i['article_title']) for i in task_batch]
  55. id_tuple = tuple([i['id'] for i in task_batch])
  56. # lock task
  57. lock_rows = self.lock_task(thread_db_client, id_tuple)
  58. if lock_rows:
  59. prompt = category_generation_from_title(title_batch)
  60. # print(prompt)
  61. completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json')
  62. for article in title_batch:
  63. article_id = str(article[0])
  64. category = completion.get(article_id)
  65. self.set_category(thread_db_client, article_id, category)
  66. else:
  67. return
  68. except Exception as e:
  69. print(e)
  70. def get_task_list(self):
  71. """
  72. get task_list from a database
  73. """
  74. fetch_query = f"""
  75. select id, article_title from publish_single_video_source
  76. where category_status = 0 and bad_status = 0
  77. order by score desc;
  78. """
  79. fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor)
  80. return fetch_result
  81. def deal(self):
  82. task_list = self.get_task_list()
  83. task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
  84. max_workers = 5
  85. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  86. # 提交所有任务到线程池
  87. futures = [executor.submit(self.deal_in_each_thread, task_batch)
  88. for task_batch in task_batch_list]
  89. # 用 tqdm 跟踪任务完成进度
  90. for _ in tqdm(
  91. concurrent.futures.as_completed(futures),
  92. total=len(futures),
  93. desc="Processing batches"
  94. ):
  95. pass # 仅用于更新进度条,不需要结果
  96. if __name__ == '__main__':
  97. category_generation_task = CategoryGenerationTask()
  98. category_generation_task.deal()