فهرست منبع

开发:在冷启动阶段增加品类

luojunhui 1 روز پیش
والد
کامیت
9af920d53c
4فایلهای تغییر یافته به همراه97 افزوده شده و 16 حذف شده
  1. 1 1
      applications/api/deep_seek_api_official.py
  2. 21 0
      applications/const/__init__.py
  3. 72 12
      tasks/ai_tasks/category_generation.py
  4. 3 3
      tasks/ai_tasks/prompts.py

+ 1 - 1
applications/api/deep_seek_api_official.py

@@ -7,7 +7,7 @@ from openai import OpenAI
 from config import deep_seek_official_model
 from config import deep_seek_official_api_key
 
-def fetch_deepseek_response(model, prompt, output_type='text'):
+def fetch_deepseek_completion(model, prompt, output_type='text'):
     """
     deep_seek方法
     """

+ 21 - 0
applications/const/__init__.py

@@ -455,3 +455,24 @@ class GoogleVideoUnderstandTaskConst:
     TASK_NAME = "extract_video_best_frame_as_cover"
     DIR_NAME = "static"
     POOL_SIZE = 15
+
+
+class CategoryGenerationTaskConst:
+    """
+    const for category generation task
+    """
+    # task batch size
+    BATCH_SIZE = 20
+
+    # task status
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    SUCCESS_STATUS = 2
+    FAIL_STATUS = 99
+
+    # max processing time
+    MAX_PROCESSING_TIME = 3600
+
+    # task info
+    TABLE_NAME = "publish_single_video_source"
+    TASK_NAME = "generate_category_with_title"

+ 72 - 12
tasks/ai_tasks/category_generation.py

@@ -1,14 +1,19 @@
 """
 generate category for given title
 """
+import concurrent
+from concurrent.futures import ThreadPoolExecutor
 from pymysql.cursors import DictCursor
+from tqdm import tqdm
 
-from applications.api import deep_seek_api_by_volcanoengine
+from applications.api.deep_seek_api_official import fetch_deepseek_completion
+from applications.const import CategoryGenerationTaskConst
 from applications.db import DatabaseConnector
 from applications.utils import yield_batch
 from config import long_articles_config
 from tasks.ai_tasks.prompts import category_generation_from_title
 
+
 class CategoryGenerationTask:
     """
     generate category for given title
@@ -16,14 +21,58 @@ class CategoryGenerationTask:
     def __init__(self):
         self.db_client = DatabaseConnector(long_articles_config)
         self.db_client.connect()
+        self.const = CategoryGenerationTaskConst()
 
-    def deal_in_each_thread(self, task_batch):
-        thread_db_client = DatabaseConnector(long_articles_config)
-        thread_db_client.connect()
+    def set_category(self, thread_db_client, article_id, category):
+        """
+        set category for given article
+        """
+        update_query = f"""
+            update publish_single_video_source
+            set category = %s, category_status = %s
+            where id = %s and category_status = %s;
+        """
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(category, self.const.SUCCESS_STATUS, article_id, self.const.PROCESSING_STATUS)
+        )
+        return update_rows
 
-        task_id_batch = [i['id'] for i in task_batch]
-        title_batch = [i['title'] for i in task_batch]
-        prompt = category_generation_from_title(title_batch)
+    def lock_task(self, thread_db_client, article_id_tuple):
+        """
+        lock_task
+        """
+        update_query = f"""
+            update publish_single_video_source
+            set category_status = %s
+            where id in %s and category_status = %s;
+        """
+        update_rows = thread_db_client.save(
+            query=update_query,
+            params=(self.const.PROCESSING_STATUS, article_id_tuple, self.const.INIT_STATUS)
+        )
+        return update_rows
+
+    def deal_in_each_thread(self, task_batch):
+        try:
+            thread_db_client = DatabaseConnector(long_articles_config)
+            thread_db_client.connect()
+            title_batch = [(i['id'], i['article_title']) for i in task_batch]
+            id_tuple = tuple([i['id'] for i in task_batch])
+            # lock task
+            lock_rows = self.lock_task(thread_db_client, id_tuple)
+            if lock_rows:
+                prompt = category_generation_from_title(title_batch)
+                # print(prompt)
+                completion = fetch_deepseek_completion(model="DeepSeek-V3", prompt=prompt, output_type='json')
+                for article in title_batch:
+                    article_id = str(article[0])
+                    category = completion.get(article_id)
+                    self.set_category(thread_db_client, article_id, category)
+            else:
+                return
+        except Exception as e:
+            print(e)
 
     def get_task_list(self):
         """
@@ -31,17 +80,28 @@ class CategoryGenerationTask:
         """
         fetch_query = f"""
             select id, article_title from publish_single_video_source
-            where category_status = 0 
-            order by score desc limit 100;
+            where category_status = 0 and bad_status = 0
+            order by score desc;
         """
         fetch_result = self.db_client.fetch(fetch_query, cursor_type=DictCursor)
         return fetch_result
 
     def deal(self):
         task_list = self.get_task_list()
-        task_batch_list = yield_batch(data=task_list, batch_size=20)
-        for task_batch in task_batch_list:
-            self.deal_in_each_thread(task_batch)
+        task_batch_list = yield_batch(data=task_list, batch_size=self.const.BATCH_SIZE)
+        max_workers = 5
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # 提交所有任务到线程池
+            futures = [executor.submit(self.deal_in_each_thread, task_batch)
+                       for task_batch in task_batch_list]
+
+            # 用 tqdm 跟踪任务完成进度
+            for _ in tqdm(
+                    concurrent.futures.as_completed(futures),
+                    total=len(futures),
+                    desc="Processing batches"
+            ):
+                pass  # 仅用于更新进度条,不需要结果
 
 
 if __name__ == '__main__':

+ 3 - 3
tasks/ai_tasks/prompts.py

@@ -6,7 +6,6 @@ def category_generation_from_title(title_list):
     """
     generate prompt category for given title
     """
-    title_lines = "\n".join(title_list)
     prompt = f"""
         请帮我完成以下任务:输入为文章的标题,根据标题判断其内容所属的类目,输出为文章标题及其对应的类目。
         类目需从以下15个品类内选择:
@@ -116,9 +115,10 @@ def category_generation_from_title(title_list):
         外卖时代将被终结?一个全新行业正悄悄取代外卖,你准备好了吗?
         准备存款的一定要知道,今明两年,定期存款要记住“4不存”
         
-        最后输出结果请用JSON格式输出,key为标题,value为品类,仅输出JSON,不要markdown格式,不要任何其他内容
+        输入是一个 LIST, LIST 中的每个元素是一个元组,元组的第一个元素是文章的 ID,第二个元素是文章的标题。
+        最后输出结果请用JSON格式输出,key为ID,value为品类,仅输出JSON,不要markdown格式,不要任何其他内容
         如果标题中包含半角双引号,则进行转义
-        以下是需要分析的文章标题列表,每一行是一个标题:{title_lines}
+        输入的 LIST 是 {title_list}
     """
     return prompt