dev: add categories during the cold-start stage

luojunhui 11 hours ago
commit 15178a2e7e

+ 9 - 0
applications/const/__init__.py

@@ -461,9 +461,18 @@ class CategoryGenerationTaskConst:
     """
     const for category generation task
     """
+    # max worker threads
+    MAX_WORKERS = 5
+
     # task batch size
     BATCH_SIZE = 20
 
+    # minimum batch size
+    MIN_BATCH_SIZE = 1
+
+    # article status
+    ARTICLE_GOOD_STATUS = 0
+
     # task status
     INIT_STATUS = 0
     PROCESSING_STATUS = 1

+ 0 - 116
tasks/ai_tasks/category_generation.py

@@ -1,116 +0,0 @@
-"""
-generate category for given title
-"""
-import concurrent
-from concurrent.futures import ThreadPoolExecutor
-from pymysql.cursors import DictCursor
-from tqdm import tqdm
-
-from applications.api.deep_seek_api_official import fetch_deepseek_completion
-from applications.const import CategoryGenerationTaskConst
-from applications.db import DatabaseConnector
-from applications.utils import yield_batch
-from config import long_articles_config
-from tasks.ai_tasks.prompts import category_generation_from_title
-
-
-class CategoryGenerationTask:
-    """
-    generate category for given title
-    """
-    def __init__(self):
-        self.db_client = DatabaseConnector(long_articles_config)
-        self.db_client.connect()
-        self.const = CategoryGenerationTaskConst()
-
-    def set_category(self, thread_db_client, article_id, category):
-        """
-        set category for given article
-        """
-        update_query = f"""
-            update publish_single_video_source
-            set category = %s, category_status = %s
-            where id = %s and category_status = %s;
-        """
-        update_rows = thread_db_client.save(
-            query=update_query,
-            params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS)
-        )
-        return update_rows
-
-    def lock_task(self, thread_db_client, article_id_tuple):
-        """
-        lock_task
-        """
-        update_query = f"""
-            update publish_single_video_source
-            set category_status = %s
-            where id in %s and category_status = %s;
-        """
-        update_rows = thread_db_client.save(
-            query=update_query,
-            params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS)
-        )
-        return update_rows
-
-    def deal_in_each_thread(self, task_batch):
-        try:
-            thread_db_client = DatabaseConnector(long_articles_config)
-            thread_db_client.connect()
-            title_batch = [(i['id'], i['article_title']) for i in task_batch]
-            id_tuple = tuple([i['id'] for i in task_batch])
-            # lock task
-            lock_rows = self.lock_task(thread_db_client, id_tuple)
-            if lock_rows:
-                prompt = category_generation_from_title(title_batch)
-                # print(prompt)
-                completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json')
-                for article in title_batch:
-                    article_id = str(article[0])
-                    category = completion.get(article_id)
-                    self.set_category(thread_db_client, article_id, category)
-            else:
-                return
-        except Exception as e:
-            print(e)
-
-    def get_task_list(self):
-        """
-        get task_list from a database
-        """
-        fetch_query = f"""
-            select id, article_title from publish_single_video_source
-            where category_status = 0 and bad_status = 0
-            order by score desc;
-        """
-        fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor)
-        return fetch_result
-
-    def deal(self):
-        task_list = self.get_task_list()
-        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
-        max_workers = 5
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # submit all tasks to the thread pool
-            futures = [executor.submit(self.deal_in_each_thread, task_batch)
-                       for task_batch in task_batch_list]
-
-            # track task completion progress with tqdm
-            for _ in tqdm(
-                    concurrent.futures.as_completed(futures),
-                    total=len(futures),
-                    desc="Processing batches"
-            ):
-                pass  # only advances the progress bar; the result is not needed
-
-
-if __name__ == '__main__':
-    category_generation_task = CategoryGenerationTask()
-    category_generation_task.deal()
-
-
-
-
-
-
-

+ 235 - 0
tasks/ai_tasks/category_generation_task.py

@@ -0,0 +1,235 @@
+"""
+generate category for given title
+"""
+import time
+import concurrent
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+from pymysql.cursors import DictCursor
+from tqdm import tqdm
+
+from applications import log
+from applications.api.deep_seek_api_official import fetch_deepseek_completion
+from applications.const import CategoryGenerationTaskConst
+from applications.db import DatabaseConnector
+from applications.utils import yield_batch
+
+from config import long_articles_config
+from tasks.ai_tasks.prompts import category_generation_from_title
+
+
+class CategoryGenerationTask:
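+    """
+    generate category for given title
+    """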
+
+    def __init__(self):
+        self.db_client = DatabaseConnector(long_articles_config)
+        self.db_client.connect()
+        self.const = CategoryGenerationTaskConst()
+
+    def set_category_status_as_success(
+        self, thread_db_client: DatabaseConnector, article_id: int, category: str
+    ) -> int:
+
+        update_query = """
+            update publish_single_video_source
+            set category = %s, category_status = %s, category_status_update_ts = %s
+            where id = %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                category,
+                self.const.SUCCESS_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
+        )
+        return update_rows
+
+    def set_category_status_as_fail(
+        self, thread_db_client: DatabaseConnector, article_id: int
+    ) -> int:
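+        """
+        mark a locked article as failed
+        """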
+
+        update_query = """
+            update publish_single_video_source
+            set category_status = %s, category_status_update_ts = %s
+            where id = %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS),
+        )
+        return update_rows
+
+    def update_title_category(
+        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):
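+        """
+        read the category for the given article id from the AI completion;
+        mark the row as failed if the completion cannot be parsed
+        """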
+
+        try:
+            # a missing key raises KeyError, which routes the row to the fail path
+            category = completion[str(article_id)]
+            self.set_category_status_as_success(thread_db_client, article_id, category)
+
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                function="update_each_record_status",
+                message="AI返回格式失败,更新状态为失败",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc()
+                }
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def rollback_lock_tasks(self) -> int:
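+        """
+        release stale locks: rows stuck in PROCESSING for longer than
+        MAX_PROCESSING_TIME are reset to INIT so they can be retried
+        """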
+
+        update_query = """
+            update publish_single_video_source  
+            set category_status = %s
+            where category_status = %s and category_status_update_ts <= %s;
+        """
+
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                self.const.INIT_STATUS,
+                self.const.PROCESSING_STATUS,
+                int(time.time()) - self.const.MAX_PROCESSING_TIME
+            )
+        )
+
+        return update_rows
+
+    def lock_task(
+        self, thread_db_client: DatabaseConnector, article_id_tuple: tuple[int, ...]
+    ) -> int:
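+        """
+        atomically claim the given article ids by flipping their
+        category_status from INIT to PROCESSING
+        """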
+
+        update_query = """
+            update publish_single_video_source
+            set category_status = %s, category_status_update_ts = %s
+            where id in %s and category_status = %s;
+        """
+
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(
+                self.const.PROCESSING_STATUS,
+                int(time.time()),
+                article_id_tuple,
+                self.const.INIT_STATUS,
+            ),
+        )
+        return update_rows
+
+    def deal_each_article(self, thread_db_client, article: dict):
+        """
+        deal each article
+        """
+        article_id = article["id"]
+        title = article["article_title"]
+
+        id_tuple = (article_id, )
+        title_batch = [(article_id, title)]
+
+        lock_rows = self.lock_task(thread_db_client, id_tuple)
+        if lock_rows:
+            prompt = category_generation_from_title(title_batch)
+            try:
+                completion = fetch_deepseek_completion(
+                    model="DeepSeek-V3", prompt=prompt, output_type="json"
+                )
+                self.update_title_category(thread_db_client, article_id, completion)
+
+            except Exception as e:
+                log(
+                    task=self.const.TASK_NAME,
+                    message="该文章存在敏感词,AI 拒绝返回",
+                    function="deal_each_article",
+                    data={
+                        "article_id": article_id,
+                        "error": str(e),
+                        "traceback": traceback.format_exc()
+                    }
+                )
+                self.set_category_status_as_fail(thread_db_client, article_id)
+
+    def deal_batch_in_each_thread(self, task_batch: list[dict]):
+        """
+        deal in each thread
+        """
+        thread_db_client = DatabaseConnector(long_articles_config)
+        thread_db_client.connect()
+
+        title_batch = [(i["id"], i["article_title"]) for i in task_batch]
+        id_tuple = tuple([int(i["id"]) for i in task_batch])
+
+        lock_rows = self.lock_task(thread_db_client, id_tuple)
+        if lock_rows:
+            prompt = category_generation_from_title(title_batch)
+
+            try:
+                completion = fetch_deepseek_completion(
+                    model="DeepSeek-V3", prompt=prompt, output_type="json"
+                )
+            except Exception as e:
+                log(
+                    task=self.const.TASK_NAME,
+                    function="category_generation_task",
+                    message=" batch 中存在敏感词,AI 拒绝返回",
+                    data={
+                        "article_id": id_tuple,
+                        "error": str(e),
+                        "traceback": traceback.format_exc()
+                    }
+                )
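+                # fall back to per-article requests so one bad title
+                # does not block the whole batch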
+                for article in tqdm(task_batch):
+                    self.deal_each_article(thread_db_client, article)
+
+                return
+
+            for article in title_batch:
+                self.update_title_category(thread_db_client, article[0], completion)
+
+        else:
+            return
+
+    def get_task_list(self):
+        """
+        get task_list from a database
+        """
+        fetch_query = f"""
+            select id, article_title from publish_single_video_source
+            where category_status = %s and bad_status = %s
+            order by score desc limit 20;
+        """
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
+        )
+        return fetch_result
+
+    def deal(self):
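+        """
+        entry point: release stale locks, then process pending articles
+        batch by batch
+        """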
+
+        self.rollback_lock_tasks()
+
+        task_list = self.get_task_list()
+        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
+        for task_batch in task_batch_list:
+            self.deal_batch_in_each_thread(task_batch)
+
+        # with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+        #     futures = [
+        #         executor.submit(self.deal_batch_in_each_thread, task_batch)
+        #         for task_batch in task_batch_list
+        #     ]
+        #
+        #     for _ in tqdm(
+        #         concurrent.futures.as_completed(futures),
+        #         total=len(futures),
+        #         desc="Processing batches",
+        #     ):
+        #         pass
+
+
+if __name__ == "__main__":
+    category_generation_task = CategoryGenerationTask()
+    category_generation_task.deal()