@@ -0,0 +1,241 @@
+"""
+Generate a category for each given article title.
+"""
+
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from pymysql.cursors import DictCursor
+from tqdm import tqdm
+
+from applications import log
+from applications.api.deep_seek_api_official import fetch_deepseek_completion
+from applications.const import CategoryGenerationTaskConst
+from applications.db import DatabaseConnector
+from applications.utils import yield_batch
+
+from config import long_articles_config
+from tasks.ai_tasks.prompts import category_generation_from_title
+
+
+class CategoryGenerationTask:
+    """Generate categories for pending video titles and persist the results."""
+
+    def __init__(self):
+        self.db_client = DatabaseConnector(long_articles_config)
+        self.db_client.connect()
+        self.const = CategoryGenerationTaskConst()
+
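+    # Status lifecycle for publish_single_video_source.category_status (constant
+    # names from CategoryGenerationTaskConst; the concrete values are assumed):
+    #   INIT -> PROCESSING (lock_task) -> SUCCESS or FAIL,
+    # with rollback_lock_tasks returning stale PROCESSING rows to INIT.
+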
+    def set_category_status_as_success(
+        self, thread_db_client: DatabaseConnector, article_id: int, category: str
+    ) -> int:
+        """Persist the generated category and flip the row from processing to success."""
+        update_query = """
+            update publish_single_video_source
+            set category = %s, category_status = %s, category_status_update_ts = %s
+            where id = %s and category_status = %s;
+        """
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                category,
+                self.const.SUCCESS_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def set_category_status_as_fail(
+        self, thread_db_client: DatabaseConnector, article_id: int
+    ) -> int:
+        """Flip the row from processing to failed without touching the category column."""
+        update_query = """
+            update publish_single_video_source
+            set category_status = %s, category_status_update_ts = %s
+            where id = %s and category_status = %s;
+        """
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.FAIL_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def update_title_category(
+        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):
+        """Extract one article's category from the AI completion and store it."""
+        try:
+            # completion is expected to map str(article_id) to a category label;
+            # indexing (rather than .get) makes a missing id mark the row as failed
+            category = completion[str(article_id)]
+            self.set_category_status_as_success(thread_db_client, article_id, category)
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="update_title_category",
+                message="AI response format invalid; updating status to failed",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def rollback_lock_tasks(self) -> int:
+        """Reset rows stuck in processing beyond MAX_PROCESSING_TIME back to init."""
+        update_query = """
+            update publish_single_video_source
+            set category_status = %s
+            where category_status = %s and category_status_update_ts <= %s;
+        """
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                self.const.INIT_STATUS,
+                self.const.PROCESSING_STATUS,
+                int(time.time()) - self.const.MAX_PROCESSING_TIME,
+            ),
+        )
+        return update_rows
+
+    def lock_task(
+        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
+    ) -> int:
+        """Optimistically lock a batch by flipping its rows from init to processing."""
+        update_query = """
+            update publish_single_video_source
+            set category_status = %s, category_status_update_ts = %s
+            where id in %s and category_status = %s;
+        """
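+        # Note: pymysql expands a tuple bound to "in %s" into a parenthesized
+        # value list, so article_id_tuple can be passed as a single parameter
+        # (assuming DatabaseConnector.save forwards params to a pymysql cursor).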
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.PROCESSING_STATUS,
+                int(time.time()),
+                article_id_tuple,
+                self.const.INIT_STATUS,
+            ),
+        )
+        return update_rows
+
+    def deal_each_article(self, thread_db_client: DatabaseConnector, article: dict):
+        """Classify a single article; used as a fallback when a batch request fails."""
+        article_id = article["id"]
+        title = article["article_title"]
+
+        title_batch = [(article_id, title)]
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+            self.update_title_category(thread_db_client, article_id, completion)
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="deal_each_article",
+                message="article contains sensitive words; AI refused to respond",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def deal_batch_in_each_thread(self, task_batch: list[dict]):
+        """Classify one batch of articles inside a worker thread."""
+        # each worker thread gets its own connection; shared pymysql
+        # connections are not safe to use across threads
+        thread_db_client = DatabaseConnector(long_articles_config)
+        thread_db_client.connect()
+
+        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
+        id_tuple = tuple(int(i["id"]) for i in task_batch)
+
+        lock_rows = self.lock_task(thread_db_client, id_tuple)
+        if not lock_rows:
+            # another worker already claimed this batch
+            return
+
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="deal_batch_in_each_thread",
+                message="batch contains sensitive words; AI refused to respond",
+                data={
+                    "article_id": id_tuple,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            # retry one by one so a single bad title does not fail the whole batch
+            for article in tqdm(task_batch):
+                self.deal_each_article(thread_db_client, article)
+            return
+
+        for article_id, _ in title_batch:
+            self.update_title_category(thread_db_client, article_id, completion)
+
+    def get_task_list(self) -> list[dict]:
+        """Fetch unprocessed articles from the database, highest score first."""
+        fetch_query = """
+            select id, article_title from publish_single_video_source
+            where category_status = %s and bad_status = %s
+            order by score desc;
+        """
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
+        )
+        return fetch_result
+
+    def deal(self):
+        """Entry point: recover stale locks, then classify pending titles in batches."""
+        self.rollback_lock_tasks()
+
+        task_list = self.get_task_list()
+        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
+
+        # dev: process batches serially instead of through the thread pool
+        # for task_batch in task_batch_list:
+        #     self.deal_batch_in_each_thread(task_batch)
+
+        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+            futures = [
+                executor.submit(self.deal_batch_in_each_thread, task_batch)
+                for task_batch in task_batch_list
+            ]
+            # drain futures as they finish so tqdm can report batch progress
+            for _ in tqdm(
+                as_completed(futures),
+                total=len(futures),
+                desc="Processing batches",
+            ):
+                pass
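+
+
+# Minimal usage sketch (hypothetical entry point; the real scheduler that
+# invokes this task may differ): run the full task once as a standalone script.
+if __name__ == "__main__":
+    CategoryGenerationTask().deal()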