category_generation_task.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. """
  2. generate category for given title
  3. """
  4. import time
  5. import concurrent
  6. import traceback
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pymysql.cursors import DictCursor
  9. from tqdm import tqdm
  10. from applications import log
  11. from applications.api.deep_seek_api_official import fetch_deepseek_completion
  12. from applications.const import CategoryGenerationTaskConst
  13. from applications.db import DatabaseConnector
  14. from applications.utils import yield_batch
  15. from config import long_articles_config
  16. from tasks.ai_tasks.prompts import category_generation_from_title
  17. class CategoryGenerationTask:
  18. def __init__(self):
  19. self.db_client = DatabaseConnector(long_articles_config)
  20. self.db_client.connect()
  21. self.const = CategoryGenerationTaskConst()
  22. def rollback_lock_tasks(self, table_name) -> int:
  23. """
  24. 回滚锁定的任务
  25. :param table_name:
  26. """
  27. update_query = f"""
  28. update {table_name}
  29. set category_status = %s
  30. where category_status = %s and category_status_update_ts <= %s;
  31. """
  32. update_rows = self.db_client.save(
  33. query=update_query,
  34. params=(
  35. self.const.INIT_STATUS,
  36. self.const.PROCESSING_STATUS,
  37. int(time.time()) - self.const.MAX_PROCESSING_TIME,
  38. ),
  39. )
  40. return update_rows
  41. class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
  42. def set_category_status_as_success(
  43. self, thread_db_client: DatabaseConnector, article_id: int, category: str
  44. ) -> int:
  45. update_query = f"""
  46. update publish_single_video_source
  47. set category = %s, category_status = %s, category_status_update_ts = %s
  48. where id = %s and category_status = %s;
  49. """
  50. update_rows = thread_db_client.save(
  51. query=update_query,
  52. params=(
  53. category,
  54. self.const.SUCCESS_STATUS,
  55. int(time.time()),
  56. article_id,
  57. self.const.PROCESSING_STATUS,
  58. ),
  59. )
  60. return update_rows
  61. def set_category_status_as_fail(
  62. self, thread_db_client: DatabaseConnector, article_id: int
  63. ) -> int:
  64. update_query = f"""
  65. update publish_single_video_source
  66. set category_status = %s, category_status_update_ts = %s
  67. where id = %s and category_status = %s;
  68. """
  69. update_rows = thread_db_client.save(
  70. query=update_query,
  71. params=(
  72. self.const.FAIL_STATUS,
  73. int(time.time()),
  74. article_id,
  75. self.const.PROCESSING_STATUS,
  76. ),
  77. )
  78. return update_rows
  79. def update_title_category(
  80. self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
  81. ):
  82. try:
  83. category = completion.get(str(article_id))
  84. self.set_category_status_as_success(thread_db_client, article_id, category)
  85. except Exception as e:
  86. log(
  87. task=self.const.TASK_NAME,
  88. function="update_each_record_status",
  89. message="AI返回格式失败,更新状态为失败",
  90. data={
  91. "article_id": article_id,
  92. "error": str(e),
  93. "traceback": traceback.format_exc(),
  94. },
  95. )
  96. self.set_category_status_as_fail(thread_db_client, article_id)
  97. def lock_task(
  98. self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
  99. ) -> int:
  100. update_query = f"""
  101. update publish_single_video_source
  102. set category_status = %s, category_status_update_ts = %s
  103. where id in %s and category_status = %s;
  104. """
  105. update_rows = thread_db_client.save(
  106. query=update_query,
  107. params=(
  108. self.const.PROCESSING_STATUS,
  109. int(time.time()),
  110. article_id_tuple,
  111. self.const.INIT_STATUS,
  112. ),
  113. )
  114. return update_rows
  115. def deal_each_article(self, thread_db_client, article: dict):
  116. """
  117. deal each article
  118. """
  119. article_id = article["id"]
  120. title = article["article_title"]
  121. title_batch = [(article_id, title)]
  122. prompt = category_generation_from_title(title_batch)
  123. try:
  124. completion = fetch_deepseek_completion(
  125. model="DeepSeek-V3", prompt=prompt, output_type="json"
  126. )
  127. self.update_title_category(thread_db_client, article_id, completion)
  128. except Exception as e:
  129. log(
  130. task=self.const.TASK_NAME,
  131. message="该文章存在敏感词,AI 拒绝返回",
  132. function="deal_each_article",
  133. data={
  134. "article_id": article_id,
  135. "error": str(e),
  136. "traceback": traceback.format_exc(),
  137. },
  138. )
  139. self.set_category_status_as_fail(thread_db_client, article_id)
  140. def deal_batch_in_each_thread(self, task_batch: list[dict]):
  141. """
  142. deal in each thread
  143. """
  144. thread_db_client = DatabaseConnector(long_articles_config)
  145. thread_db_client.connect()
  146. title_batch = [(i["id"], i["article_title"]) for i in task_batch]
  147. id_tuple = tuple([int(i["id"]) for i in task_batch])
  148. lock_rows = self.lock_task(thread_db_client, id_tuple)
  149. if lock_rows:
  150. prompt = category_generation_from_title(title_batch)
  151. try:
  152. completion = fetch_deepseek_completion(
  153. model="DeepSeek-V3", prompt=prompt, output_type="json"
  154. )
  155. except Exception as e:
  156. log(
  157. task=self.const.TASK_NAME,
  158. function="category_generation_task",
  159. message=" batch 中存在敏感词,AI 拒绝返回",
  160. data={
  161. "article_id": id_tuple,
  162. "error": str(e),
  163. "traceback": traceback.format_exc(),
  164. },
  165. )
  166. for article in tqdm(task_batch):
  167. self.deal_each_article(thread_db_client, article)
  168. return
  169. for article in title_batch:
  170. self.update_title_category(thread_db_client, article[0], completion)
  171. else:
  172. return
  173. def get_task_list(self):
  174. """
  175. get task_list from a database
  176. """
  177. fetch_query = f"""
  178. select id, article_title from publish_single_video_source
  179. where category_status = %s and bad_status = %s
  180. order by score desc;
  181. """
  182. fetch_result = self.db_client.fetch(
  183. query=fetch_query,
  184. cursor_type=DictCursor,
  185. params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
  186. )
  187. return fetch_result
  188. def deal(self):
  189. self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)
  190. task_list = self.get_task_list()
  191. task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
  192. ## dev
  193. # for task_batch in task_batch_list:
  194. # self.deal_batch_in_each_thread(task_batch)
  195. with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
  196. futures = [
  197. executor.submit(self.deal_batch_in_each_thread, task_batch)
  198. for task_batch in task_batch_list
  199. ]
  200. for _ in tqdm(
  201. concurrent.futures.as_completed(futures),
  202. total=len(futures),
  203. desc="Processing batches",
  204. ):
  205. pass
  206. class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
  207. def set_category_status_as_success(
  208. self, thread_db_client: DatabaseConnector, article_id: int, category: str
  209. ) -> int:
  210. update_query = f"""
  211. update {self.const.ARTICLE_TABLE_NAME}
  212. set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
  213. where article_id = %s and category_status = %s;
  214. """
  215. update_rows = thread_db_client.save(
  216. query=update_query,
  217. params=(
  218. category,
  219. self.const.SUCCESS_STATUS,
  220. int(time.time()),
  221. article_id,
  222. self.const.PROCESSING_STATUS,
  223. ),
  224. )
  225. return update_rows
  226. def set_category_status_as_fail(
  227. self, thread_db_client: DatabaseConnector, article_id: int
  228. ) -> int:
  229. update_query = f"""
  230. update {self.const.ARTICLE_TABLE_NAME}
  231. set category_status = %s, category_status_update_ts = %s
  232. where article_id = %s and category_status = %s;
  233. """
  234. update_rows = thread_db_client.save(
  235. query=update_query,
  236. params=(
  237. self.const.FAIL_STATUS,
  238. int(time.time()),
  239. article_id,
  240. self.const.PROCESSING_STATUS,
  241. ),
  242. )
  243. return update_rows
  244. def update_title_category(
  245. self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
  246. ):
  247. try:
  248. category = completion.get(str(article_id))
  249. self.set_category_status_as_success(thread_db_client, article_id, category)
  250. except Exception as e:
  251. log(
  252. task=self.const.TASK_NAME,
  253. function="update_each_record_status",
  254. message="AI返回格式失败,更新状态为失败",
  255. data={
  256. "article_id": article_id,
  257. "error": str(e),
  258. "traceback": traceback.format_exc(),
  259. },
  260. )
  261. self.set_category_status_as_fail(thread_db_client, article_id)
  262. def lock_task(
  263. self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
  264. ) -> int:
  265. update_query = f"""
  266. update {self.const.ARTICLE_TABLE_NAME}
  267. set category_status = %s, category_status_update_ts = %s
  268. where article_id in %s and category_status = %s;
  269. """
  270. update_rows = thread_db_client.save(
  271. query=update_query,
  272. params=(
  273. self.const.PROCESSING_STATUS,
  274. int(time.time()),
  275. article_id_tuple,
  276. self.const.INIT_STATUS,
  277. ),
  278. )
  279. return update_rows
  280. def deal_each_article(self, thread_db_client, article: dict):
  281. """
  282. deal each article
  283. """
  284. article_id = article["article_id"]
  285. title = article["title"]
  286. title_batch = [(article_id, title)]
  287. prompt = category_generation_from_title(title_batch)
  288. try:
  289. completion = fetch_deepseek_completion(
  290. model="DeepSeek-V3", prompt=prompt, output_type="json"
  291. )
  292. self.update_title_category(thread_db_client, article_id, completion)
  293. except Exception as e:
  294. log(
  295. task=self.const.TASK_NAME,
  296. message="该文章存在敏感词,AI 拒绝返回",
  297. function="deal_each_article",
  298. data={
  299. "article_id": article_id,
  300. "error": str(e),
  301. "traceback": traceback.format_exc(),
  302. },
  303. )
  304. self.set_category_status_as_fail(thread_db_client, article_id)
  305. def deal_batch_in_each_thread(self, task_batch: list[dict]):
  306. """
  307. deal in each thread
  308. """
  309. thread_db_client = DatabaseConnector(long_articles_config)
  310. thread_db_client.connect()
  311. title_batch = [(i["article_id"], i["title"]) for i in task_batch]
  312. id_tuple = tuple([int(i["article_id"]) for i in task_batch])
  313. lock_rows = self.lock_task(thread_db_client, id_tuple)
  314. if lock_rows:
  315. prompt = category_generation_from_title(title_batch)
  316. try:
  317. completion = fetch_deepseek_completion(
  318. model="DeepSeek-V3", prompt=prompt, output_type="json"
  319. )
  320. except Exception as e:
  321. log(
  322. task=self.const.TASK_NAME,
  323. function="article_category_generation_task",
  324. message="batch 中存在敏感词,AI 拒绝返回",
  325. data={
  326. "article_id": id_tuple,
  327. "error": str(e),
  328. "traceback": traceback.format_exc(),
  329. },
  330. )
  331. for article in tqdm(task_batch):
  332. self.deal_each_article(thread_db_client, article)
  333. return
  334. for article in title_batch:
  335. self.update_title_category(thread_db_client, article[0], completion)
  336. else:
  337. return
  338. def get_task_list(self):
  339. """
  340. get task_list from a database
  341. """
  342. fetch_query = f"""
  343. select article_id, title from {self.const.ARTICLE_TABLE_NAME}
  344. where category_status = %s and status = %s and score > %s
  345. order by score desc limit 1000;
  346. """
  347. fetch_result = self.db_client.fetch(
  348. query=fetch_query,
  349. cursor_type=DictCursor,
  350. params=(self.const.INIT_STATUS, self.const.ARTICLE_INIT_STATUS, self.const.LIMIT_SCORE),
  351. )
  352. return fetch_result
  353. def deal(self):
  354. self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
  355. task_list = self.get_task_list()
  356. task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
  357. # # dev
  358. # for task_batch in task_batch_list:
  359. # self.deal_batch_in_each_thread(task_batch)
  360. with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
  361. futures = [
  362. executor.submit(self.deal_batch_in_each_thread, task_batch)
  363. for task_batch in task_batch_list
  364. ]
  365. for _ in tqdm(
  366. concurrent.futures.as_completed(futures),
  367. total=len(futures),
  368. desc="Processing batches",
  369. ):
  370. pass