
dev: add categories during the cold-start stage

luojunhui 10 hours ago
parent
commit
27bb11f56d
4 changed files with 74 additions and 61 deletions
  1. +0 -9   run_title_rewrite_task.py
  2. +1 -1   sh/run_title_rewrite_task.sh
  3. +57 -51 tasks/ai_tasks/category_generation_task.py
  4. +16 -0  title_process_task.py

+ 0 - 9
run_title_rewrite_task.py

@@ -1,9 +0,0 @@
-"""
-@author: luojunhui
-"""
-from tasks.title_rewrite_task import TitleRewriteTask
-
-
-if __name__ == '__main__':
-    task = TitleRewriteTask()
-    task.deal()

+ 1 - 1
sh/run_title_rewrite_task.sh

@@ -21,6 +21,6 @@ else
     conda activate tasks
 
     # Run the Python script in the background and redirect its log output
-    nohup python3 run_title_rewrite_task.py >> "${LOG_FILE}" 2>&1 &
+    nohup python3 title_process_task.py >> "${LOG_FILE}" 2>&1 &
     echo "$(date '+%Y-%m-%d %H:%M:%S') - successfully restarted run_title_rewrite_task.py"
 fi

+ 57 - 51
tasks/ai_tasks/category_generation_task.py

@@ -1,6 +1,7 @@
 """
 generate category for given title
 """
+
 import time
 import concurrent
 import traceback
@@ -60,11 +61,18 @@ class CategoryGenerationTask:
 
         update_rows = thread_db_client.save(
             query=update_query,
-            params=(self.const.FAIL_STATUS, int(time.time()), article_id, self.const.PROCESSING_STATUS),
+            params=(
+                self.const.FAIL_STATUS,
+                int(time.time()),
+                article_id,
+                self.const.PROCESSING_STATUS,
+            ),
         )
         return update_rows
 
-    def update_title_category(self, thread_db_client: DatabaseConnector, article_id: int, completion: dict):
+    def update_title_category(
+        self, thread_db_client: DatabaseConnector, article_id: int, completion: dict
+    ):
 
         try:
             category = completion.get(str(article_id))
@@ -78,8 +86,8 @@ class CategoryGenerationTask:
                 data={
                     "article_id": article_id,
                     "error": str(e),
-                    "traceback": traceback.format_exc()
-                }
+                    "traceback": traceback.format_exc(),
+                },
             )
             self.set_category_status_as_fail(thread_db_client, article_id)
 
@@ -96,8 +104,8 @@ class CategoryGenerationTask:
             params=(
                 self.const.INIT_STATUS,
                 self.const.PROCESSING_STATUS,
-                int(time.time()) - self.const.MAX_PROCESSING_TIME
-            )
+                int(time.time()) - self.const.MAX_PROCESSING_TIME,
+            ),
         )
 
         return update_rows
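
Note: the hunk above only reformats the parameter tuple, but it belongs to the watchdog that rolls rows stuck in PROCESSING for longer than MAX_PROCESSING_TIME back to INIT so a later run can reclaim them. The UPDATE statement itself sits outside the hunk; a minimal sketch of its plausible shape, assuming the same publish_single_video_source table and a hypothetical category_status_update_ts column:

import time

def rollback_stuck_rows(db_client, const):
    # Roll PROCESSING rows older than MAX_PROCESSING_TIME back to INIT.
    # Table and timestamp column names are assumptions; only the parameter
    # tuple (INIT_STATUS, PROCESSING_STATUS, cutoff) is visible in this diff.
    update_query = """
        update publish_single_video_source
        set category_status = %s
        where category_status = %s and category_status_update_ts < %s;
    """
    return db_client.save(
        query=update_query,
        params=(
            const.INIT_STATUS,
            const.PROCESSING_STATUS,
            int(time.time()) - const.MAX_PROCESSING_TIME,
        ),
    )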
@@ -130,30 +138,27 @@ class CategoryGenerationTask:
         article_id = article["id"]
         title = article["article_title"]
 
-        id_tuple = (article_id, )
         title_batch = [(article_id, title)]
 
-        lock_rows = self.lock_task(thread_db_client, id_tuple)
-        if lock_rows:
-            prompt = category_generation_from_title(title_batch)
-            try:
-                completion = fetch_deepseek_completion(
-                    model="DeepSeek-V3", prompt=prompt, output_type="json"
-                )
-                self.update_title_category(thread_db_client, article_id, completion)
+        prompt = category_generation_from_title(title_batch)
+        try:
+            completion = fetch_deepseek_completion(
+                model="DeepSeek-V3", prompt=prompt, output_type="json"
+            )
+            self.update_title_category(thread_db_client, article_id, completion)
 
-            except Exception as e:
-                log(
-                    task=self.const.TASK_NAME,
-                    message="The article contains sensitive words; the AI refused to respond",
-                    function="deal_each_article",
-                    data={
-                        "article_id": article_id,
-                        "error": str(e),
-                        "traceback": traceback.format_exc()
-                    }
-                )
-                self.set_category_status_as_fail(thread_db_client, article_id)
+        except Exception as e:
+            log(
+                task=self.const.TASK_NAME,
+                message="The article contains sensitive words; the AI refused to respond",
+                function="deal_each_article",
+                data={
+                    "article_id": article_id,
+                    "error": str(e),
+                    "traceback": traceback.format_exc(),
+                },
+            )
+            self.set_category_status_as_fail(thread_db_client, article_id)
 
     def deal_batch_in_each_thread(self, task_batch: list[dict]):
         """
@@ -181,8 +186,8 @@ class CategoryGenerationTask:
                     data={
                         "article_id": id_tuple,
                         "error": str(e),
-                        "traceback": traceback.format_exc()
-                    }
+                        "traceback": traceback.format_exc(),
+                    },
                 )
                 for article in tqdm(task_batch):
                     self.deal_each_article(thread_db_client, article)
@@ -202,9 +207,13 @@ class CategoryGenerationTask:
         fetch_query = f"""
             select id, article_title from publish_single_video_source
             where category_status = %s and bad_status = %s
-            order by score desc limit 20;
+            order by score desc;
         """
-        fetch_result = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor, params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS))
+        fetch_result = self.db_client.fetch(
+            query=fetch_query,
+            cursor_type=DictCursor,
+            params=(self.const.INIT_STATUS, self.const.ARTICLE_GOOD_STATUS),
+        )
         return fetch_result
 
     def deal(self):
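
Dropping "limit 20" means get_task_list now returns every pending row, so BATCH_SIZE (applied through yield_batch in deal below) is what bounds the work each thread receives. yield_batch is imported but not shown in this commit; a minimal sketch of the generator it presumably is:

from typing import Iterator

def yield_batch(data: list, batch_size: int) -> Iterator[list]:
    # Yield consecutive slices of at most batch_size items; assumed shape,
    # matching the call yield_batch(data=task_list, batch_size=...).
    for i in range(0, len(data), batch_size):
        yield data[i : i + batch_size]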
@@ -213,23 +222,20 @@ class CategoryGenerationTask:
 
         task_list = self.get_task_list()
         task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
-        for task_batch in task_batch_list:
-            self.deal_batch_in_each_thread(task_batch)
-
-        # with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
-        #     futures = [
-        #         executor.submit(self.deal_in_each_thread, task_batch)
-        #         for task_batch in task_batch_list
-        #     ]
-        #
-        #     for _ in tqdm(
-        #         concurrent.futures.as_completed(futures),
-        #         total=len(futures),
-        #         desc="Processing batches",
-        #     ):
-        #         pass
-
-
-if __name__ == "__main__":
-    category_generation_task = CategoryGenerationTask()
-    category_generation_task.deal()
+
+        ## dev: sequential fallback, kept for local debugging
+        # for task_batch in task_batch_list:
+        #     self.deal_batch_in_each_thread(task_batch)
+
+        with ThreadPoolExecutor(max_workers=self.const.MAX_WORKERS) as executor:
+            futures = [
+                executor.submit(self.deal_batch_in_each_thread, task_batch)
+                for task_batch in task_batch_list
+            ]
+
+            for _ in tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(futures),
+                desc="Processing batches",
+            ):
+                pass
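
This is the behavioral core of the commit: the sequential dev loop is retired and batches fan out across a thread pool, with as_completed plus tqdm reporting progress as batches finish in arbitrary order. One consequence worth noting: MySQL connections are generally not thread-safe, which is presumably why each worker uses its own thread_db_client rather than sharing self.db_client. A minimal sketch of that per-thread setup, assuming DatabaseConnector exposes a connect() method; the config object name is hypothetical:

def deal_batch_in_each_thread(self, task_batch: list[dict]):
    # Each worker thread opens its own connection; sharing one MySQL
    # connection across threads is unsafe. The config object name below
    # is hypothetical, not taken from this diff.
    thread_db_client = DatabaseConnector(long_articles_config)
    thread_db_client.connect()
    for article in task_batch:
        self.deal_each_article(thread_db_client, article)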

+ 16 - 0
title_process_task.py

@@ -0,0 +1,16 @@
+"""
+@author: luojunhui
+"""
+from tasks.ai_tasks.category_generation_task import CategoryGenerationTask
+from tasks.title_rewrite_task import TitleRewriteTask
+
+
+if __name__ == '__main__':
+    # 1. title rewrite
+    title_rewrite_task = TitleRewriteTask()
+    title_rewrite_task.deal()
+
+    # 2. title categorization
+    category_generation_task = CategoryGenerationTask()
+    category_generation_task.deal()
+
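
The new entry point runs the two stages strictly in sequence, so an uncaught exception in TitleRewriteTask would also skip category generation. A hedged variant (not part of this commit) that isolates the stages:

import traceback

from tasks.ai_tasks.category_generation_task import CategoryGenerationTask
from tasks.title_rewrite_task import TitleRewriteTask

if __name__ == "__main__":
    # Run each stage independently so a failure in one does not block the other.
    for task_cls in (TitleRewriteTask, CategoryGenerationTask):
        try:
            task_cls().deal()
        except Exception:
            traceback.print_exc()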