- """
- generate category for given title
- """
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

from pymysql.cursors import DictCursor
from tqdm import tqdm

from applications import log
from applications.api.deep_seek_api_official import fetch_deepseek_completion
from applications.const import CategoryGenerationTaskConst
from applications.db import DatabaseConnector
from applications.utils import yield_batch
from config import long_articles_config
from tasks.ai_tasks.prompts import category_generation_from_title
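
# For reference, the status constants consumed below are assumed to look
# roughly like this. This is an illustrative sketch only; the real values
# live in applications.const.CategoryGenerationTaskConst:
#
#     class CategoryGenerationTaskConst:
#         TASK_NAME = "category_generation_task"
#         INIT_STATUS = 0              # not yet processed
#         PROCESSING_STATUS = 1        # locked by a worker
#         SUCCESS_STATUS = 2           # category stored
#         FAIL_STATUS = 3              # AI refused or response malformed
#         ARTICLE_GOOD_STATUS = 0      # bad_status filter value
#         MAX_PROCESSING_TIME = 3600   # seconds before a stale lock is rolled back
#         BATCH_SIZE = 20
#         MAX_WORKERS = 8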


class CategoryGenerationTask:
    def __init__(self):
        self.db_client = DatabaseConnector(long_articles_config)
        self.db_client.connect()
        self.const = CategoryGenerationTaskConst()

    def set_category_status_as_success(
        self, thread_db_client: DatabaseConnector, article_id: int, category: str
    ) -> int:
        """Store the generated category and mark the row as successfully processed."""
        update_query = """
            update publish_single_video_source
            set category = %s, category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                category,
                self.const.SUCCESS_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def set_category_status_as_fail(
        self, thread_db_client: DatabaseConnector, article_id: int
    ) -> int:
        """Mark the row as failed without touching its category column."""
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id = %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.FAIL_STATUS,
                int(time.time()),
                article_id,
                self.const.PROCESSING_STATUS,
            ),
        )
        return update_rows

    def update_title_category(
        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
    ):
        """Persist the category parsed from the AI completion, or mark the row failed."""
        try:
            # Index (rather than .get) so a missing key raises and is logged
            # as a failure instead of silently storing a NULL category.
            category = completion[str(article_id)]
            self.set_category_status_as_success(thread_db_client, article_id, category)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="update_title_category",
                message="AI response was malformed; marking status as failed",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def rollback_lock_tasks(self) -> int:
        """Release rows stuck in PROCESSING longer than MAX_PROCESSING_TIME."""
        update_query = """
            update publish_single_video_source
            set category_status = %s
            where category_status = %s and category_status_update_ts <= %s;
        """
        update_rows = self.db_client.save(
            query=update_query,
            params=(
                self.const.INIT_STATUS,
                self.const.PROCESSING_STATUS,
                int(time.time()) - self.const.MAX_PROCESSING_TIME,
            ),
        )
        return update_rows

    def lock_task(
        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
    ) -> int:
        """Atomically claim a batch of rows by flipping INIT -> PROCESSING."""
        update_query = """
            update publish_single_video_source
            set category_status = %s, category_status_update_ts = %s
            where id in %s and category_status = %s;
        """
        update_rows = thread_db_client.save(
            query=update_query,
            params=(
                self.const.PROCESSING_STATUS,
                int(time.time()),
                article_id_tuple,
                self.const.INIT_STATUS,
            ),
        )
        return update_rows
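
    # Note: lock_task and rollback_lock_tasks together implement a simple
    # timestamp-based lease. Each worker claims rows with a conditional
    # UPDATE (INIT -> PROCESSING), and any claim older than
    # MAX_PROCESSING_TIME is returned to INIT on the next run, so a crashed
    # worker cannot strand rows in PROCESSING forever.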

    def deal_each_article(self, thread_db_client, article: dict):
        """Generate and store a category for a single article."""
        article_id = article["id"]
        title = article["article_title"]
        title_batch = [(article_id, title)]
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
            self.update_title_category(thread_db_client, article_id, completion)
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_each_article",
                message="Article contains sensitive words; the AI refused to respond",
                data={
                    "article_id": article_id,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            self.set_category_status_as_fail(thread_db_client, article_id)

    def deal_batch_in_each_thread(self, task_batch: list[dict]):
        """Claim a batch of rows, classify all titles in one completion, and
        fall back to per-article processing if the batched call fails."""
        # Each thread needs its own connection; pymysql connections are not
        # safe to share across threads.
        thread_db_client = DatabaseConnector(long_articles_config)
        thread_db_client.connect()
        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
        id_tuple = tuple(int(i["id"]) for i in task_batch)
        lock_rows = self.lock_task(thread_db_client, id_tuple)
        if not lock_rows:
            # Another worker already claimed this batch.
            return
        prompt = category_generation_from_title(title_batch)
        try:
            completion = fetch_deepseek_completion(
                model="DeepSeek-V3", prompt=prompt, output_type="json"
            )
        except Exception as e:
            log(
                task=self.const.TASK_NAME,
                function="deal_batch_in_each_thread",
                message="Batch contains sensitive words; the AI refused to respond",
                data={
                    "article_id": id_tuple,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )
            # Retry one article at a time so a single bad title does not
            # fail the whole batch.
            for article in tqdm(task_batch):
                self.deal_each_article(thread_db_client, article)
            return
        for article in title_batch:
            self.update_title_category(thread_db_client, article[0], completion)

    def get_task_list(self):
        """Fetch unprocessed, non-bad articles from the database, highest score first."""
        fetch_query = """
            select id, article_title from publish_single_video_source
            where category_status = %s and bad_status = %s
            order by score desc;
        """
        fetch_result = self.db_client.fetch(
            query=fetch_query,
            cursor_type=DictCursor,
            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
        )
        return fetch_result

    def deal(self):
        """Entry point: roll back stale locks, then process all pending
        articles in batches across a thread pool."""
        self.rollback_lock_tasks()
        task_list = self.get_task_list()
        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)

        # dev: process batches serially for easier debugging
        # for task_batch in task_batch_list:
        #     self.deal_batch_in_each_thread(task_batch)

        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
            futures = [
                executor.submit(self.deal_batch_in_each_thread, task_batch)
                for task_batch in task_batch_list
            ]
            for _ in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Processing batches",
            ):
                pass
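

# A minimal usage sketch, assuming this module is run directly; the real
# scheduler or entry point is not shown in this file:
if __name__ == "__main__":
    task = CategoryGenerationTask()
    task.deal()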