""" generate category for given title """ import time import concurrent import traceback from concurrent.futures import ThreadPoolExecutor from pymysql.cursors import DictCursor from pandas import DataFrame from tqdm import tqdm from applications import log from applications.api.deep_seek_api_official import fetch_deepseek_completion from applications.const import CategoryGenerationTaskConst from applications.db import DatabaseConnector from applications.utils import yield_batch from config import long_articles_config from tasks.ai_tasks.prompts import category_generation_from_title class CategoryGenerationTask: def __init__(self): self.db_client = DatabaseConnector(long_articles_config) self.db_client.connect() self.const = CategoryGenerationTaskConst() def rollback_lock_tasks(self, table_name) -> int: """ 回滚锁定的任务 :param table_name: """ update_query = f""" update {table_name} set category_status = %s where category_status = %s and category_status_update_ts <= %s; """ update_rows = self.db_client.save( query=update_query, params=( self.const.INIT_STATUS, self.const.PROCESSING_STATUS, int(time.time()) - self.const.MAX_PROCESSING_TIME, ), ) return update_rows class VideoPoolCategoryGenerationTask(CategoryGenerationTask): def set_category_status_as_success( self, thread_db_client: DatabaseConnector, article_id: int, category: str ) -> int: update_query = f""" update publish_single_video_source set category = %s, category_status = %s, category_status_update_ts = %s where id = %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=( category, self.const.SUCCESS_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS, ), ) return update_rows def set_category_status_as_fail( self, thread_db_client: DatabaseConnector, article_id: int ) -> int: update_query = f""" update publish_single_video_source set category_status = %s, category_status_update_ts = %s where id = %s and category_status = %s; """ update_rows = thread_db_client.save( query=update_query, params=( self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS, ), ) return update_rows def update_title_category( self, thread_db_client: DatabaseConnector, article_id: int, completion: dict ): try: category = completion.get(str(article_id)) self.set_category_status_as_success(thread_db_client, article_id, category) except Exception as e: log( task=self.const.TASK_NAME, function="update_each_record_status", message="AI返回格式失败,更新状态为失败", data={ "article_id": article_id, "error": str(e), "traceback": traceback.format_exc(), }, ) self.set_category_status_as_fail(thread_db_client, article_id) def lock_task( self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...] 
class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        try:
            # index rather than .get, so a missing key raises and marks the task
            # as failed instead of silently storing a NULL category as SUCCESS
            category = completion[str(article_id)]
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="AI response format invalid, marking task as failed",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id in %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )
        return update_rows

    def deal_each_article(self, thread_db_client, article: dict):
        """
        deal with a single article
        """
        article_id = article["id"]
        title = article["article_title"]
        title_batch = [(article_id, title)]
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_each_article",
                message="article contains sensitive words, AI refused to respond",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """
        deal with one batch inside a worker thread
        """
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()
        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
        id_tuple = tuple(int(i["id"]) for i in task_batch)
        lock_rows = self.lock_task(thread_db_client, id_tuple)
        if not lock_rows:
            return
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message="batch contains sensitive words, AI refused to respond",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # fall back to processing the batch article by article
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return
        for article_id, _ in title_batch:
            self.update_title_category(thread_db_client, article_id, completion)

    def get_task_list(self):
        """
        get task list from the database
        """
        fetch_query = """
            select id, article_title
            from publish_single_video_source
            where category_status = %s and bad_status = %s
            order by score desc;
        """
        fetch_result = self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
        )
        return fetch_result

    def deal(self):
        self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)
        task_list = self.get_task_list()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)

        # dev
        # for task_batch in task_batch_list:
        #     self.deal_batch_in_each_thread(task_batch)

        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass

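# Expected shape of the DeepSeek JSON completion, inferred from
# update_title_category, which indexes the response by str(article_id);
# the concrete category labels below are hypothetical:
#   {"<article_id>": "<category>", ...}, e.g. {"1001": "历史", "1002": "健康"}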
class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        try:
            # index rather than .get, so a missing key raises and marks the task
            # as failed instead of silently storing a NULL category as SUCCESS
            category = completion[str(article_id)]
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="AI response format invalid, marking task as failed",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        update_query = f"""
            update {self.const.ARTICLE_TABLE_NAME}
            set category_status = %s, category_status_update_ts = %s
            where article_id in %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )
        return update_rows

    def deal_each_article(self, thread_db_client, article: dict):
        """
        deal with a single article
        """
        article_id = article["article_id"]
        title = article["title"]
        title_batch = [(article_id, title)]
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_each_article",
                message="article contains sensitive words, AI refused to respond",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """
        deal with one batch inside a worker thread
        """
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()
        title_batch = [(i["article_id"], i["title"]) for i in task_batch]
        id_tuple = tuple(int(i["article_id"]) for i in task_batch)
        lock_rows = self.lock_task(thread_db_client, id_tuple)
        if not lock_rows:
            return
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message="batch contains sensitive words, AI refused to respond",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # fall back to processing the batch article by article
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return
        for article_id, _ in title_batch:
            self.update_title_category(thread_db_client, article_id, completion)

    def get_task_list(self):
        """
        get task list from the database
        """
        fetch_query = f"""
            select article_id, title
            from {self.const.ARTICLE_TABLE_NAME}
            where category_status = %s and status = %s and score > %s
            order by score desc
            limit 1000;
        """
        fetch_result = self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(
                self.const.INIT_STATUS,
                self.const.ARTICLE_INIT_STATUS,
                self.const.LIMIT_SCORE,
            ),
        )
        return fetch_result

    def get_task_v2(self):
        fetch_query = """
            select article_id, out_account_id, article_index, title, read_cnt, status, score
            from crawler_meta_article
            where category = 'account_association'
                and title_sensitivity = 0
                and platform = 'weixin'
            order by score desc;
        """
        article_list = self.db_client.fetch(query=fetch_query)
        articles_df = DataFrame(
            article_list,
            columns=["article_id", "gh_id", "position", "title", "read_cnt", "status", "score"],
        )

        # compute each article's read multiple relative to the mean read count
        # of its account/position group
        articles_df["average_read"] = articles_df.groupby(["gh_id", "position"])[
            "read_cnt"
        ].transform("mean")
        articles_df["read_times"] = articles_df["read_cnt"] / articles_df["average_read"]

        # layer 0: keep only published articles
        filter_df = articles_df[articles_df["status"] == 1]
        # layer 1: filter by read multiple
        filter_df = filter_df[filter_df["read_times"] >= 1.3]
        # layer 2: filter by read count
        filter_df = filter_df[filter_df["read_cnt"] >= 5000]
        # layer 3: filter by title length
        filter_df = filter_df[
            (filter_df["title"].str.len() >= 15) & (filter_df["title"].str.len() <= 50)
        ]
        # layer 4: filter out titles containing sensitive keywords
        blocked_keywords = [
            "农历", "太极", "节", "早上好", "赖清德", "普京", "俄",
            "南海", "台海", "解放军", "蔡英文", "中国",
        ]
        filter_df = filter_df[~filter_df["title"].str.contains("|".join(blocked_keywords))]
        # layer 5: filter by relevance score
        filter_df = filter_df[filter_df["score"] > 0.4]

        return filter_df[["article_id", "title"]].to_dict(orient="records")

    def deal(self):
        self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
        task_list = self.get_task_v2()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)

        # dev
        # for task_batch in task_batch_list:
        #     self.deal_batch_in_each_thread(task_batch)

        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass

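# Minimal entry-point sketch (an assumption: this module does not show how the
# tasks are scheduled; the actual wiring may live elsewhere in the repo).
if __name__ == "__main__":
    VideoPoolCategoryGenerationTask().deal()
    ArticlePoolCategoryGenerationTask().deal()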