@@ -27,6 +27,31 @@ class CategoryGenerationTask:
         self.db_client.connect()
         self.const = CategoryGenerationTaskConst()

+    def rollback_lock_tasks(self, table_name: str) -> int:
+        """
+        Roll back locked tasks
+        :param table_name: table whose stale locked rows should be reset
+        """
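+        # Rows stuck in PROCESSING for more than MAX_PROCESSING_TIME seconds
+        # are reset to INIT so a later run can pick them up again.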
+        update_query = f"""
+            update {table_name}
+            set category_status = %s
+            where category_status = %s and category_status_update_ts <= %s;
+        """
+
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                self.const.INIT_STATUS,
+                self.const.PROCESSING_STATUS,
+                int(time.time()) - self.const.MAX_PROCESSING_TIME,
+            ),
+        )
+
+        return update_rows
+
+
+class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
+
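+    # Video-pool variant: rollback now goes through the shared, table-aware
+    # rollback_lock_tasks() instead of the hard-coded query removed below.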
     def set_category_status_as_success(
         self, thread_db_client: DatabaseConnector, article_id: int, category: str
     ) -> int:
@@ -91,25 +116,6 @@ class CategoryGenerationTask:
         )
         self.set_category_status_as_fail(thread_db_client, article_id)

-    def rollback_lock_tasks(self) -> int:
-
-        update_query = f"""
-            update publish_single_video_source
-            set category_status = %s
-            where category_status = %s and category_status_update_ts <= %s;
-        """
-
-        update_rows = self.db_client.save(
-            query=update_query,
-            params=(
-                self.const.INIT_STATUS,
-                self.const.PROCESSING_STATUS,
-                int(time.time()) - self.const.MAX_PROCESSING_TIME,
-            ),
-        )
-
-        return update_rows
-
     def lock_task(
         self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
     ) -> int:
@@ -218,7 +224,7 @@ class CategoryGenerationTask:

     def deal(self):

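+        # Reclaim stale locks in the video table before fetching new work.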
-        self.rollback_lock_tasks()
+        self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)

         task_list = self.get_task_list()
         task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
@@ -239,3 +245,200 @@ class CategoryGenerationTask:
             desc="Processing batches",
         ):
             pass
+
+
+class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
+    def set_category_status_as_success(
+        self, thread_db_client: DatabaseConnector, article_id: int, category: str
+    ) -> int:
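+        # The status guard in the WHERE clause means only rows this worker
+        # still holds in PROCESSING are marked successful.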
+
+        update_query = f"""
+            update {self.const.ARTICLE_TABLE_NAME}
+            set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
+            where article_id = %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                category,
+                self.const.SUCCESS_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def set_category_status_as_fail(
+        self, thread_db_client: DatabaseConnector, article_id: int
+    ) -> int:
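+        # Failed rows keep FAIL_STATUS; get_task_list() only selects INIT
+        # rows, so they are not retried automatically.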
+
+        update_query = f"""
+            update {self.const.ARTICLE_TABLE_NAME}
+            set category_status = %s, category_status_update_ts = %s
+            where article_id = %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.FAIL_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def update_title_category(
+        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):
+
+        try:
+            category = completion.get(str(article_id))
+            if not category:
+                # completion is keyed by str(article_id); a missing key means
+                # the model skipped this article, which counts as a failure
+                raise ValueError(f"no category returned for article {article_id}")
+            self.set_category_status_as_success(thread_db_client, article_id, category)
+
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="update_title_category",
+                message="AI response format invalid, updating status to failed",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def lock_task(
+        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
+    ) -> int:
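+        # Optimistic lock: INIT rows flip to PROCESSING in one statement; the
+        # MySQL driver is assumed to expand the tuple for the `in %s` clause.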
+
+        update_query = f"""
+            update {self.const.ARTICLE_TABLE_NAME}
+            set category_status = %s, category_status_update_ts = %s
+            where article_id in %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.PROCESSING_STATUS,
+                int(time.time()),
+                article_id_tuple,
+                self.const.INIT_STATUS,
+            ),
+        )
+        return update_rows
+
+    def deal_each_article(self, thread_db_client: DatabaseConnector, article: dict):
+        """
+        Process a single article
+        """
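+        # Single-article fallback used when a whole-batch request is rejected,
+        # so one bad title cannot block the others.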
+        article_id = article["article_id"]
+        title = article["title"]
+
+        title_batch = [(article_id, title)]
+
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+            self.update_title_category(thread_db_client, article_id, completion)
+
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                message="article contains sensitive words; AI refused to respond",
+                function="deal_each_article",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def deal_batch_in_each_thread(self, task_batch: list[dict]):
+        """
+        Process one batch of articles inside a worker thread
+        """
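+        # Each worker thread opens its own connection; the shared client is
+        # not assumed to be thread-safe.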
+        thread_db_client = DatabaseConnector(long_articles_config)
+        thread_db_client.connect()
+
+        title_batch = [(i["article_id"], i["title"]) for i in task_batch]
+        id_tuple = tuple(int(i["article_id"]) for i in task_batch)
+
+        lock_rows = self.lock_task(thread_db_client, id_tuple)
+        if lock_rows:
+            prompt = category_generation_from_title(title_batch)
+
+            try:
+                completion = fetch_deepseek_completion(
+                    model="DeepSeek-V3", prompt=prompt, output_type="json"
+                )
+            except Exception as e:
+                log(
+                    task=self.const.TASK_NAME,
+                    function="deal_batch_in_each_thread",
+                    message="batch contains sensitive words; AI refused to respond",
+                    data={
+                        "article_id": id_tuple,
+                        "error": str(e),
+                        "traceback": traceback.format_exc(),
+                    },
+                )
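+                # Fall back to per-article requests so one rejected title does
+                # not fail the whole batch.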
+                for article in tqdm(task_batch):
+                    self.deal_each_article(thread_db_client, article)
+
+                return
+
+            for article in title_batch:
+                self.update_title_category(thread_db_client, article[0], completion)
+
+    def get_task_list(self):
+        """
+        Fetch up to 1000 pending articles, highest score first
+        """
+        fetch_query = f"""
+            select article_id, title from {self.const.ARTICLE_TABLE_NAME}
+            where category_status = %s and status = %s and score > %s
+            order by score desc limit 1000;
+        """
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_INIT_STATUS, self.const.LIMIT_SCORE),
+        )
+        return fetch_result
+
+    def deal(self):
+
+        self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
+
+        task_list = self.get_task_list()
+        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
+
+        # dev: process batches serially for debugging
+        # for task_batch in task_batch_list:
+        #     self.deal_batch_in_each_thread(task_batch)
+
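+        # Fan batches out across a thread pool; the tqdm loop just drains
+        # as_completed() to drive the progress bar.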
+        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+            futures = [
+                executor.submit(self.deal_batch_in_each_thread, task_batch)
+                for task_batch in task_batch_list
+            ]
+
+            for _ in tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(futures),
+                desc="Processing batches",
+            ):
+                pass