فهرست منبع

Merge branch '2025-05-06-article-title-category-task-imporve'

luojunhui 1 ماه پیش
والد
کامیت
ed7c0f3cdc
4 فایل تغییر یافته به همراه 250 افزوده شده و 27 حذف شده
  1. 12 3
      applications/const/__init__.py
  2. 223 20
      tasks/ai_tasks/category_generation_task.py
  3. 6 0
      temp_task.py
  4. 9 4
      title_process_task.py

+ 12 - 3
applications/const/__init__.py

@@ -460,7 +460,7 @@ class CategoryGenerationTaskConst:
     const for category generation task
     """
     # MAX THREAD
-    MAX_WORKERS = 5
+    MAX_WORKERS = 10
 
     # task batch size
     BATCH_SIZE = 20
@@ -481,5 +481,14 @@ class CategoryGenerationTaskConst:
     MAX_PROCESSING_TIME = 3600
 
     # task info
-    TABLE_NAME = "publish_single_video_source"
-    TASK_NAME = "generate_category_with_title"
+    VIDEO_TABLE_NAME = "publish_single_video_source"
+    ARTICLE_TABLE_NAME = "crawler_meta_article"
+    TASK_NAME = "generate_category_with_title"
+
+    # article_status
+    ARTICLE_INIT_STATUS = 1
+    ARTICLE_PUBLISHED_STATUS = 2
+    ARTICLE_BAD_STATUS = 0
+
+    # limit score
+    LIMIT_SCORE = 0.4

+ 223 - 20
tasks/ai_tasks/category_generation_task.py

@@ -27,6 +27,31 @@ class CategoryGenerationTask:
         self.db_client.connect()
         self.const = CategoryGenerationTaskConst()
 
+    def rollback_lock_tasks(self, table_name) -> int:
+        """
+        回滚锁定的任务
+        :param table_name:
+        """
+        update_query = f"""
+               update {table_name}
+               set category_status = %s
+               where category_status = %s and category_status_update_ts <= %s;
+           """
+
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                self.const.INIT_STATUS,
+                self.const.PROCESSING_STATUS,
+                int(time.time()) - self.const.MAX_PROCESSING_TIME,
+            ),
+        )
+
+        return update_rows
+
+
+class VideoPoolCategoryGenerationTask(CategoryGenerationTask):
+
     def set_category_status_as_success(
         self, thread_db_client: DatabaseConnector, article_id: int, category: str
     ) -> int:
@@ -91,25 +116,6 @@ class CategoryGenerationTask:
             )
             self.set_category_status_as_fail(thread_db_client, article_id)
 
-    def rollback_lock_tasks(self) -> int:
-
-        update_query = f"""
-            update publish_single_video_source  
-            set category_status = %s
-            where category_status = %s and category_status_update_ts <= %s;
-        """
-
-        update_rows = self.db_client.save(
-            query=update_query,
-            params=(
-                self.const.INIT_STATUS,
-                self.const.PROCESSING_STATUS,
-                int(time.time()) - self.const.MAX_PROCESSING_TIME,
-            ),
-        )
-
-        return update_rows
-
     def lock_task(
         self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
     ) -> int:
@@ -218,7 +224,7 @@ class CategoryGenerationTask:
 
     def deal(self):
 
-        self.rollback_lock_tasks()
+        self.rollback_lock_tasks(self.const.VIDEO_TABLE_NAME)
 
         task_list = self.get_task_list()
         task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
@@ -239,3 +245,200 @@ class CategoryGenerationTask:
                 desc="Processing batches",
             ):
                 pass
+
+
+class ArticlePoolCategoryGenerationTask(CategoryGenerationTask):
+    def set_category_status_as_success(
+            self, thread_db_client: DatabaseConnector, article_id: int, category: str
+    ) -> int:
+
+        update_query = f"""
+               update {self.const.ARTICLE_TABLE_NAME}
+               set category_by_ai = %s, category_status = %s, category_status_update_ts = %s
+               where article_id = %s and category_status = %s;
+           """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                category,
+                self.const.SUCCESS_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def set_category_status_as_fail(
+            self, thread_db_client: DatabaseConnector, article_id: int
+    ) -> int:
+
+        update_query = f"""
+               update {self.const.ARTICLE_TABLE_NAME}
+               set category_status = %s, category_status_update_ts = %s
+               where article_id = %s and category_status = %s;
+           """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.FAIL_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def update_title_category(
+            self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):
+
+        try:
+            category = completion.get(str(article_id))
+            self.set_category_status_as_success(thread_db_client, article_id, category)
+
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="update_each_record_status",
+                message="AI返回格式失败,更新状态为失败",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+
+    def lock_task(
+            self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
+    ) -> int:
+
+        update_query = f"""
+               update {self.const.ARTICLE_TABLE_NAME}
+               set category_status = %s, category_status_update_ts = %s
+               where article_id in %s and category_status = %s;
+           """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.PROCESSING_STATUS,
+                int(time.time()),
+                article_id_tuple,
+                self.const.INIT_STATUS,
+            ),
+        )
+        return update_rows
+
+    def deal_each_article(self, thread_db_client, article: dict):
+        """
+        deal each article
+        """
+        article_id = article["article_id"]
+        title = article["title"]
+
+        title_batch = [(article_id, title)]
+
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+            self.update_title_category(thread_db_client, article_id, completion)
+
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                message="该文章存在敏感词,AI 拒绝返回",
+                function="deal_each_article",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def deal_batch_in_each_thread(self, task_batch: list[dict]):
+        """
+        deal in each thread
+        """
+        thread_db_client = DatabaseConnector(long_articles_config)
+        thread_db_client.connect()
+
+        title_batch = [(i["article_id"], i["title"]) for i in task_batch]
+        id_tuple = tuple([int(i["article_id"]) for i in task_batch])
+
+        lock_rows = self.lock_task(thread_db_client, id_tuple)
+        if lock_rows:
+            prompt = category_generation_from_title(title_batch)
+
+            try:
+                completion = fetch_deepseek_completion(
+                    model="DeepSeek-V3", prompt=prompt, output_type="json"
+                )
+            except Exception as e:
+                log(
+                    task=self.const.TASK_NAME,
+                    function="article_category_generation_task",
+                    message="batch 中存在敏感词,AI 拒绝返回",
+                    data={
+                        "article_id": id_tuple,
+                        "error": str(e),
+                        "traceback": traceback.format_exc(),
+                    },
+                )
+                for article in tqdm(task_batch):
+                    self.deal_each_article(thread_db_client, article)
+
+                return
+
+            for article in title_batch:
+                self.update_title_category(thread_db_client, article[0], completion)
+
+        else:
+            return
+
+    def get_task_list(self):
+        """
+        get task_list from a database
+        """
+        fetch_query = f"""
+               select article_id, title from {self.const.ARTICLE_TABLE_NAME}
+               where category_status = %s and status = %s and score > %s
+               order by score desc limit 1000;
+           """
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_INIT_STATUS, self.const.LIMIT_SCORE),
+        )
+        return fetch_result
+
+    def deal(self):
+
+        self.rollback_lock_tasks(self.const.ARTICLE_TABLE_NAME)
+
+        task_list = self.get_task_list()
+        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
+
+        # #  dev
+        # for task_batch in task_batch_list:
+        #     self.deal_batch_in_each_thread(task_batch)
+
+        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+            futures = [
+                executor.submit(self.deal_batch_in_each_thread, task_batch)
+                for task_batch in task_batch_list
+            ]
+
+            for _ in tqdm(
+                    concurrent.futures.as_completed(futures),
+                    total=len(futures),
+                    desc="Processing batches",
+            ):
+                pass

+ 6 - 0
temp_task.py

@@ -0,0 +1,6 @@
+from tasks.ai_tasks.category_generation_task import ArticlePoolCategoryGenerationTask
+
+
+if __name__ == '__main__':
+    article_pool_category_task = ArticlePoolCategoryGenerationTask()
+    article_pool_category_task.deal()

+ 9 - 4
title_process_task.py

@@ -1,7 +1,8 @@
 """
 @author: luojunhui
 """
-from tasks.ai_tasks.category_generation_task import CategoryGenerationTask
+from tasks.ai_tasks.category_generation_task import ArticlePoolCategoryGenerationTask
+from tasks.ai_tasks.category_generation_task import VideoPoolCategoryGenerationTask
 from tasks.ai_tasks.title_rewrite_task import TitleRewriteTask
 
 
@@ -10,7 +11,11 @@ if __name__ == '__main__':
     title_rewrite_task = TitleRewriteTask()
     title_rewrite_task.deal()
 
-    # 2. 标题分类
-    category_generation_task = CategoryGenerationTask()
-    category_generation_task.deal()
+    # 2. 视频内容池标题分类
+    video_pool_category_generation_task = VideoPoolCategoryGenerationTask()
+    video_pool_category_generation_task.deal()
+
+    # 3. 文章内容池标题分类
+    article_pool_category_generation_task = ArticlePoolCategoryGenerationTask()
+    article_pool_category_generation_task.deal()