category_generation_task.py

  1. """
  2. generate category for given title
  3. """
  4. import time
  5. import concurrent
  6. import traceback
  7. from concurrent.futures import ThreadPoolExecutor
  8. from pymysql.cursors import DictCursor
  9. from pandas import DataFrame
  10. from tqdm import tqdm
  11. from applications import log
  12. from applications.api.deep_seek_api_official import fetch_deepseek_completion
  13. from applications.const import CategoryGenerationTaskConst
  14. from applications.db import DatabaseConnector
  15. from applications.utils import yield_batch
  16. from config import long_articles_config
  17. from tasks.ai_tasks.prompts import category_generation_from_title

class CategoryGenerationTask:
    def __init__(self):
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()
        self.const = CategoryGenerationTaskConst()

    def rollback_lock_tasks(self, table_name: str) -> int:
        """
        Roll back tasks that have been stuck in PROCESSING status for longer
        than MAX_PROCESSING_TIME, so that they can be picked up again.
        :param table_name: table whose stale locks should be released
        """
        update_query = f"""
            update {table_name}
            set category_status = %s
            where category_status = %s and category_status_update_ts <= %s;
        """
        update_rows = self.db_client.save(
            query=update_query,
            params=(
                self.const.INIT_STATUS,
                self.const.PROCESSING_STATUS,
                int(time.time()) - self.const.MAX_PROCESSING_TIME,
            ),
        )
        return update_rows
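
# Status lifecycle implied by the task constants: rows start in INIT_STATUS,
# are locked to PROCESSING_STATUS by lock_task, and end in SUCCESS_STATUS or
# FAIL_STATUS; rows stuck in PROCESSING_STATUS for longer than
# MAX_PROCESSING_TIME are rolled back to INIT_STATUS by rollback_lock_tasks.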


class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        """Mark one video-pool record as successfully categorized."""
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        """Mark one video-pool record as failed and release its lock."""
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        """Persist the category returned by the model for one article."""
        try:
            category = completion.get(str(article_id))
            if not category:
                # A missing key means the model response is malformed for this id
                raise ValueError(f"no category returned for article {article_id}")
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="Malformed AI response; marking record as failed",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)
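
    # Expected shape of `completion`, inferred from how it is consumed above
    # (the ids and labels here are hypothetical):
    #   {"1024": "history", "1025": "food"}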

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        """Flip a batch of INIT rows to PROCESSING and return the locked row count."""
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id in %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )
        return update_rows

    def deal_each_article(self, thread_db_client, article: dict):
        """
        Retry a single article on its own, so one bad title cannot sink a whole batch.
        """
        article_id = article["id"]
        title = article["article_title"]
        title_batch = [(article_id, title)]
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_each_article",
                message="Article contains sensitive words; the AI refused to answer",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """
        Process one batch inside a worker thread, using its own DB connection.
        """
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()
        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
        id_tuple = tuple(int(i["id"]) for i in task_batch)
        lock_rows = self.lock_task(thread_db_client, id_tuple)
        if not lock_rows:
            # Another worker already locked this batch
            return
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message="Batch contains sensitive words; the AI refused to answer",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # Fall back to handling each article on its own
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return
        for article in title_batch:
            self.update_title_category(thread_db_client, article[0], completion)

    def get_task_list(self):
        """
        Fetch unprocessed video-pool records from the database, best scored first.
        """
        fetch_query = """
            select id, article_title from publish_single_video_source
            where category_status = %s and bad_status = %s
            order by score desc;
        """
        fetch_result = self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
        )
        return fetch_result

    def deal(self):
        self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)
        task_list = self.get_task_list()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
        # dev: process batches serially
        # for task_batch in task_batch_list:
        #     self.deal_batch_in_each_thread(task_batch)
        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass


class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        """Mark one article-pool record as successfully categorized."""
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        """Mark one article-pool record as failed and release its lock."""
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        """Persist the category returned by the model for one article."""
        try:
            category = completion.get(str(article_id))
            if not category:
                # A missing key means the model response is malformed for this id
                raise ValueError(f"no category returned for article {article_id}")
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="Malformed AI response; marking record as failed",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        """Flip a batch of INIT rows to PROCESSING and return the locked row count."""
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id in %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )
        return update_rows

    def deal_each_article(self, thread_db_client, article: dict):
        """
        Retry a single article on its own, so one bad title cannot sink a whole batch.
        """
        article_id = article["article_id"]
        title = article["title"]
        title_batch = [(article_id, title)]
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_each_article",
                message="Article contains sensitive words; the AI refused to answer",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """
        Process one batch inside a worker thread, using its own DB connection.
        """
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()
        title_batch = [(i["article_id"], i["title"]) for i in task_batch]
        id_tuple = tuple(int(i["article_id"]) for i in task_batch)
        lock_rows = self.lock_task(thread_db_client, id_tuple)
        if not lock_rows:
            # Another worker already locked this batch
            return
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message="Batch contains sensitive words; the AI refused to answer",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # Fall back to handling each article on its own
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return
        for article in title_batch:
            self.update_title_category(thread_db_client, article[0], completion)

    def get_task_list(self):
        """
        Fetch unprocessed article-pool records from the database, best scored first.
        """
        fetch_query = f"""
            select article_id, title from {self.const.ARTICLE_TABLE_NAME}
            where category_status = %s and status = %s and score > %s and read_cnt >= 5000
            order by score desc limit 100000;
        """
        fetch_result = self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_INIT_STATUS, self.const.LIMIT_SCORE),
        )
        return fetch_result

    def get_task_v2(self):
        """
        Alternative task source: account-association articles filtered through a
        multi-layer funnel on read counts, title length, keywords, and score.
        """
        fetch_query = """
            select
                article_id, out_account_id, article_index, title, read_cnt, status, score
            from
                crawler_meta_article
            where
                category = 'account_association' and title_sensitivity = 0 and platform = 'weixin'
            order by score desc
        """
        article_list = self.db_client.fetch(query=fetch_query)
        articles_df = DataFrame(
            article_list,
            columns=['article_id', 'gh_id', 'position', 'title', 'read_cnt', 'status', 'score']
        )
        # read_times: how an article's reads compare with the average for its slot
        articles_df['average_read'] = articles_df.groupby(['gh_id', 'position'])['read_cnt'].transform('mean')
        articles_df['read_times'] = articles_df['read_cnt'] / articles_df['average_read']
        # Layer 0: keep only published articles
        filter_df = articles_df[articles_df['status'] == 1]
        # Layer 1: filter by multiple of the average read count
        filter_df = filter_df[filter_df['read_times'] >= 1.3]
        # Layer 2: filter by absolute read count
        filter_df = filter_df[filter_df['read_cnt'] >= 5000]
        # Layer 3: filter by title length
        filter_df = filter_df[
            (filter_df['title'].str.len() >= 15)
            & (filter_df['title'].str.len() <= 50)
        ]
        # Layer 4: filter out sensitive keywords
        filter_df = filter_df[
            (~filter_df['title'].str.contains('农历'))
            & (~filter_df['title'].str.contains('太极'))
            & (~filter_df['title'].str.contains('节'))
            & (~filter_df['title'].str.contains('早上好'))
            & (~filter_df['title'].str.contains('赖清德'))
            & (~filter_df['title'].str.contains('普京'))
            & (~filter_df['title'].str.contains('俄'))
            & (~filter_df['title'].str.contains('南海'))
            & (~filter_df['title'].str.contains('台海'))
            & (~filter_df['title'].str.contains('解放军'))
            & (~filter_df['title'].str.contains('蔡英文'))
            & (~filter_df['title'].str.contains('中国'))
        ]
        # Layer 5: filter by relevance score
        filter_df = filter_df[filter_df['score'] > 0.4]
        result = filter_df[['article_id', 'title']].to_dict(orient='records')
        return result
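
    # Shape of the records returned by get_task_v2, per to_dict(orient='records')
    # above (the id and title here are hypothetical):
    #   [{"article_id": 123, "title": "..."}]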

    def deal(self):
        self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
        task_list = self.get_task_list()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
        # dev: process batches serially
        # for task_batch in task_batch_list:
        #     self.deal_batch_in_each_thread(task_batch)
        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass
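

# A minimal entry-point sketch (an assumption; the original file does not
# define one):
if __name__ == "__main__":
    VideoPoolCategoryGenerationTask().deal()
    ArticlePoolCategoryGenerationTask().deal()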