category_generation_task.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
"""
Generate a category for each article title in publish_single_video_source
by sending the titles to the DeepSeek completion API.
"""
  4. import time
  5. import concurrent
  6. import traceback
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pymysql.cursors import DictCursor
  9. from tqdm import tqdm
  10. from applications import log
  11. from applications.api.deep_seek_api_official import fetch_deepseek_completion
  12. from applications.const import CategoryGenerationTaskConst
  13. from applications.db import DatabaseConnector
  14. from applications.utils import yield_batch
  15. from config import long_articles_config
  16. from tasks.ai_tasks.prompts import category_generation_from_title
  17. class CategoryGenerationTask:
  18. def __init__(self):
  19. self.db_client = DatabaseConnector(long_articles_config)
  20. self.db_client.connect()
  21. self.const = CategoryGenerationTaskConst()
  22. def set_category_status_as_success(
  23. self, thread_db_client: DatabaseConnector, article_id: int, category: str
  24. ) -> int:
  25. update_query = f"""
  26. update publish_single_video_source
  27. set category = %s, category_status = %s, category_status_update_ts = %s
  28. where id = %s and category_status = %s;
  29. """
  30. update_rows = thread_db_client.save(
  31. query=update_query,
  32. params=(
  33. category,
  34. self.const.SUCCESS_STATUS,
  35. int(time.time()),
  36. article_id,
  37. self.const.PROCESSING_STATUS,
  38. ),
  39. )
  40. return update_rows
  41. def set_category_status_as_fail(
  42. self, thread_db_client: DatabaseConnector, article_id: int
  43. ) -> int:
  44. update_query = f"""
  45. update publish_single_video_source
  46. set category_status = %s, category_status_update_ts = %s
  47. where id = %s and category_status = %s;
  48. """
  49. update_rows = thread_db_client.save(
  50. query=update_query,
  51. params=(self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS),
  52. )
  53. return update_rows
  54. def update_title_category(self, thread_db_client: DatabaseConnector, article_id: int, completion: dict):
  55. try:
  56. category = completion.get(str(article_id))
  57. self.set_category_status_as_success(thread_db_client, article_id, category)
  58. except Exception as e:
  59. log(
  60. task=self.const.TASK_NAME,
  61. function="update_each_record_status",
  62. message="AI返回格式失败,更新状态为失败",
  63. data={
  64. "article_id": article_id,
  65. "error": str(e),
  66. "traceback": traceback.format_exc()
  67. }
  68. )
  69. self.set_category_status_as_fail(thread_db_client, article_id)
  70. def rollback_lock_tasks(self) -> int:
  71. update_query = f"""
  72. update publish_single_video_source
  73. set category_status = %s
  74. where category_status = %s and category_status_update_ts <= %s;
  75. """
  76. update_rows = self.db_client.save(
  77. query=update_query,
  78. params=(
  79. self.const.INIT_STATUS,
  80. self.const.PROCESSING_STATUS,
  81. int(time.time()) - self.const.MAX_PROCESSING_TIME
  82. )
  83. )
  84. return update_rows
  85. def lock_task(
  86. self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
  87. ) -> int:
  88. update_query = f"""
  89. update publish_single_video_source
  90. set category_status = %s, category_status_update_ts = %s
  91. where id in %s and category_status = %s;
  92. """
  93. update_rows = thread_db_client.save(
  94. query=update_query,
  95. params=(
  96. self.const.PROCESSING_STATUS,
  97. int(time.time()),
  98. article_id_tuple,
  99. self.const.INIT_STATUS,
  100. ),
  101. )
  102. return update_rows
  103. def deal_each_article(self, thread_db_client, article: dict):
  104. """
  105. deal each article
  106. """
  107. article_id = article["id"]
  108. title = article["article_title"]
  109. id_tuple = (article_id, )
  110. title_batch = [(article_id, title)]
  111. lock_rows = self.lock_task(thread_db_client, id_tuple)
  112. if lock_rows:
  113. prompt = category_generation_from_title(title_batch)
  114. try:
  115. completion = fetch_deepseek_completion(
  116. model="DeepSeek-V3", prompt=prompt, output_type="json"
  117. )
  118. self.update_title_category(thread_db_client, article_id, completion)
  119. except Exception as e:
  120. log(
  121. task=self.const.TASK_NAME,
  122. message="该文章存在敏感词,AI 拒绝返回",
  123. function="deal_each_article",
  124. data={
  125. "article_id": article_id,
  126. "error": str(e),
  127. "traceback": traceback.format_exc()
  128. }
  129. )
  130. self.set_category_status_as_fail(thread_db_client, article_id)
  131. def deal_batch_in_each_thread(self, task_batch: list[dict]):
  132. """
  133. deal in each thread
  134. """
  135. thread_db_client = DatabaseConnector(long_articles_config)
  136. thread_db_client.connect()
  137. title_batch = [(i["id"], i["article_title"]) for i in task_batch]
  138. id_tuple = tuple([int(i["id"]) for i in task_batch])
  139. lock_rows = self.lock_task(thread_db_client, id_tuple)
  140. if lock_rows:
  141. prompt = category_generation_from_title(title_batch)
  142. try:
  143. completion = fetch_deepseek_completion(
  144. model="DeepSeek-V3", prompt=prompt, output_type="json"
  145. )
  146. except Exception as e:
  147. log(
  148. task=self.const.TASK_NAME,
  149. function="category_generation_task",
  150. message=" batch 中存在敏感词,AI 拒绝返回",
  151. data={
  152. "article_id": id_tuple,
  153. "error": str(e),
  154. "traceback": traceback.format_exc()
  155. }
  156. )
  157. for article in tqdm(task_batch):
  158. self.deal_each_article(thread_db_client, article)
  159. return
  160. for article in title_batch:
  161. self.update_title_category(thread_db_client, article[0], completion)
  162. else:
  163. return
  164. def get_task_list(self):
  165. """
  166. get task_list from a database
  167. """
  168. fetch_query = f"""
  169. select id, article_title from publish_single_video_source
  170. where category_status = %s and bad_status = %s
  171. order by score desc limit 20;
  172. """
  173. fetch_result = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor, params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS))
  174. return fetch_result
  175. def deal(self):
  176. self.rollback_lock_tasks()
  177. task_list = self.get_task_list()
  178. task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
  179. for task_batch in task_batch_list:
  180. self.deal_batch_in_each_thread(task_batch)
  181. # with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
  182. # futures = [
  183. # executor.submit(self.deal_in_each_thread, task_batch)
  184. # for task_batch in task_batch_list
  185. # ]
  186. #
  187. # for _ in tqdm(
  188. # concurrent.futures.as_completed(futures),
  189. # total=len(futures),
  190. # desc="Processing batches",
  191. # ):
  192. # pass
  193. if __name__ == "__main__":
  194. category_generation_task = CategoryGenerationTask()
  195. category_generation_task.deal()