# category_generation_task.py
  1. """
  2. generate category for given title
  3. """
  4. import time
  5. import concurrent
  6. import traceback
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pymysql.cursors import DictCursor
  9. from tqdm import tqdm
  10. from applications import log
  11. from applications.api.deep_seek_api_official import fetch_deepseek_completion
  12. from applications.const import CategoryGenerationTaskConst
  13. from applications.db import DatabaseConnector
  14. from applications.utils import yield_batch
  15. from config import long_articles_config
  16. from tasks.ai_tasks.prompts import category_generation_from_title
  17. class CategoryGenerationTask:
  18. def __init__(self):
  19. self.db_client = DatabaseConnector(long_articles_config)
  20. self.db_client.connect()
  21. self.const = CategoryGenerationTaskConst()
  22. def set_category_status_as_success(
  23. self, thread_db_client: DatabaseConnector, article_id: int, category: str
  24. ) -> int:
  25. update_query = f"""
  26. update publish_single_video_source
  27. set category = %s, category_status = %s, category_status_update_ts = %s
  28. where id = %s and category_status = %s;
  29. """
  30. update_rows = thread_db_client.save(
  31. query=update_query,
  32. params=(
  33. category,
  34. self.const.SUCCESS_STATUS,
  35. int(time.time()),
  36. article_id,
  37. self.const.PROCESSING_STATUS,
  38. ),
  39. )
  40. return update_rows
  41. def set_category_status_as_fail(
  42. self, thread_db_client: DatabaseConnector, article_id: int
  43. ) -> int:
  44. update_query = f"""
  45. update publish_single_video_source
  46. set category_status = %s, category_status_update_ts = %s
  47. where id = %s and category_status = %s;
  48. """
  49. update_rows = thread_db_client.save(
  50. query=update_query,
  51. params=(
  52. self.const.FAIL_STATUS,
  53. int(time.time()),
  54. article_id,
  55. self.const.PROCESSING_STATUS,
  56. ),
  57. )
  58. return update_rows
  59. def update_title_category(
  60. self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
  61. ):
  62. try:
  63. category = completion.get(str(article_id))
  64. self.set_category_status_as_success(thread_db_client, article_id, category)
  65. except Exception as e:
  66. log(
  67. task=self.const.TASK_NAME,
  68. function="update_each_record_status",
  69. message="AI返回格式失败,更新状态为失败",
  70. data={
  71. "article_id": article_id,
  72. "error": str(e),
  73. "traceback": traceback.format_exc(),
  74. },
  75. )
  76. self.set_category_status_as_fail(thread_db_client, article_id)
  77. def rollback_lock_tasks(self) -> int:
  78. update_query = f"""
  79. update publish_single_video_source
  80. set category_status = %s
  81. where category_status = %s and category_status_update_ts <= %s;
  82. """
  83. update_rows = self.db_client.save(
  84. query=update_query,
  85. params=(
  86. self.const.INIT_STATUS,
  87. self.const.PROCESSING_STATUS,
  88. int(time.time()) - self.const.MAX_PROCESSING_TIME,
  89. ),
  90. )
  91. return update_rows
  92. def lock_task(
  93. self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
  94. ) -> int:
  95. update_query = f"""
  96. update publish_single_video_source
  97. set category_status = %s, category_status_update_ts = %s
  98. where id in %s and category_status = %s;
  99. """
  100. update_rows = thread_db_client.save(
  101. query=update_query,
  102. params=(
  103. self.const.PROCESSING_STATUS,
  104. int(time.time()),
  105. article_id_tuple,
  106. self.const.INIT_STATUS,
  107. ),
  108. )
  109. return update_rows
  110. def deal_each_article(self, thread_db_client, article: dict):
  111. """
  112. deal each article
  113. """
  114. article_id = article["id"]
  115. title = article["article_title"]
  116. title_batch = [(article_id, title)]
  117. prompt = category_generation_from_title(title_batch)
  118. try:
  119. completion = fetch_deepseek_completion(
  120. model="DeepSeek-V3", prompt=prompt, output_type="json"
  121. )
  122. self.update_title_category(thread_db_client, article_id, completion)
  123. except Exception as e:
  124. log(
  125. task=self.const.TASK_NAME,
  126. message="该文章存在敏感词,AI 拒绝返回",
  127. function="deal_each_article",
  128. data={
  129. "article_id": article_id,
  130. "error": str(e),
  131. "traceback": traceback.format_exc(),
  132. },
  133. )
  134. self.set_category_status_as_fail(thread_db_client, article_id)
  135. def deal_batch_in_each_thread(self, task_batch: list[dict]):
  136. """
  137. deal in each thread
  138. """
  139. thread_db_client = DatabaseConnector(long_articles_config)
  140. thread_db_client.connect()
  141. title_batch = [(i["id"], i["article_title"]) for i in task_batch]
  142. id_tuple = tuple([int(i["id"]) for i in task_batch])
  143. lock_rows = self.lock_task(thread_db_client, id_tuple)
  144. if lock_rows:
  145. prompt = category_generation_from_title(title_batch)
  146. try:
  147. completion = fetch_deepseek_completion(
  148. model="DeepSeek-V3", prompt=prompt, output_type="json"
  149. )
  150. except Exception as e:
  151. log(
  152. task=self.const.TASK_NAME,
  153. function="category_generation_task",
  154. message=" batch 中存在敏感词,AI 拒绝返回",
  155. data={
  156. "article_id": id_tuple,
  157. "error": str(e),
  158. "traceback": traceback.format_exc(),
  159. },
  160. )
  161. for article in tqdm(task_batch):
  162. self.deal_each_article(thread_db_client, article)
  163. return
  164. for article in title_batch:
  165. self.update_title_category(thread_db_client, article[0], completion)
  166. else:
  167. return
  168. def get_task_list(self):
  169. """
  170. get task_list from a database
  171. """
  172. fetch_query = f"""
  173. select id, article_title from publish_single_video_source
  174. where category_status = %s and bad_status = %s
  175. order by score desc;
  176. """
  177. fetch_result = self.db_client.fetch(
  178. query=fetch_query,
  179. cursor_type=DictCursor,
  180. params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
  181. )
  182. return fetch_result
  183. def deal(self):
  184. self.rollback_lock_tasks()
  185. task_list = self.get_task_list()
  186. task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
  187. ## dev
  188. # for task_batch in task_batch_list:
  189. # self.deal_batch_in_each_thread(task_batch)
  190. with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
  191. futures = [
  192. executor.submit(self.deal_batch_in_each_thread, task_batch)
  193. for task_batch in task_batch_list
  194. ]
  195. for _ in tqdm(
  196. concurrent.futures.as_completed(futures),
  197. total=len(futures),
  198. desc="Processing batches",
  199. ):
  200. pass