Selaa lähdekoodia

generate task v2

luojunhui 4 kuukautta sitten
vanhempi
commit
5ee7d97b71

+ 1 - 1
applications/api/__init__.py

@@ -1,7 +1,7 @@
 """
 @author: luojunhui
 """
-from .deep_seek_by_byte_dance_api import get_response_by_deep_seek_api
+from .deep_seek_by_byte_dance_api import fetch_deepseek_response
 from .google_ai_api import GoogleAIAPI
 from .moon_shot_api import generate_mini_program_title
 from .nlp_api import similarity_between_title_list

+ 1 - 1
applications/api/deep_seek_by_byte_dance_api.py

@@ -8,7 +8,7 @@ from config import deep_seek_default_model
 from config import deep_seek_api_key_byte_dance
 
 
-def get_response_by_deep_seek_api(model, prompt):
+def fetch_deepseek_response(model, prompt):
     """
     deep_seek方法
     """

+ 8 - 5
applications/const/__init__.py

@@ -212,11 +212,14 @@ class VideoToTextConst:
     """
     视频转文本任务常量配置
     """
-    # extract_status 是否提取文本状态
-    EXTRACT_INIT_STATUS = 0
-    EXTRACT_PROCESSING_STATUS = 101
-    EXTRACT_SUCCESS_STATUS = 2
-    EXTRACT_FAIL_STATUS = 99
+    # SUMMARY_STATUS
+    SUMMARY_INIT_STATUS = 0
+    SUMMARY_SUCCESS_STATUS = 1
+    SUMMARY_FAIL_STATUS = 99
+    SUMMARY_LOCK = 101
+
+    # SUMMARY_TASK_BATCH_SIZE
+    SUMMARY_BATCH_SIZE = 100
 
     # bad_status 文章质量状态
     ARTICLE_GOOD_STATUS = 0

+ 23 - 34
coldStartTasks/multi_modal/generate_text_from_video.py

@@ -63,13 +63,13 @@ class GenerateTextFromVideo(object):
 
     def input_task_list(self):
         """
-        输入任务列表, 从single_video_pool中获取
+        暂时用于处理历史任务, 新的视频在插入publish_single_video_source后会直接插入video_content_understanding表中
         """
         sql = f"""
-        select article_title, concat('https://rescdn.yishihui.com/', video_oss_path ) as video_url, audit_video_id
-        from publish_single_video_source 
-        where audit_status = {const.AUDIT_SUCCESS_STATUS} and bad_status = {const.ARTICLE_GOOD_STATUS} and extract_status = {const.EXTRACT_INIT_STATUS}
-        order by id desc;
+            select article_title, concat('https://rescdn.yishihui.com/', video_oss_path ) as video_url, audit_video_id
+            from publish_single_video_source 
+            where audit_status = {const.AUDIT_SUCCESS_STATUS} and bad_status = {const.ARTICLE_GOOD_STATUS}
+            order by id desc;
         """
         task_list = self.db.fetch(sql, cursor_type=DictCursor)
         insert_sql = f"""
@@ -83,36 +83,24 @@ class GenerateTextFromVideo(object):
         )
         print(affected_rows)
 
-    def roll_back_processing_videos(self):
+    def roll_back_lock_tasks(self):
         """
-        回滚长时间处于处理中的视频
+        回滚长时间处于处理中的任务
         """
-        sql = f"""
-            select id, status_update_timestamp
-            from video_content_understanding
-            where status = {const.VIDEO_LOCK};
+        update_sql = f"""
+            update video_content_understanding 
+            set status = %s
+            where status = %s and status_update_timestamp < %s;
         """
-        task_list = self.db.fetch(sql, cursor_type=DictCursor)
-        now_timestamp = int(time.time())
-        id_list = []
-        for task in tqdm(task_list):
-            if task['status_update_timestamp']:
-                if now_timestamp - task['status_update_timestamp'] >= const.MAX_PROCESSING_TIME:
-                    id_list.append(task['id'])
-
-        if id_list:
-            update_sql = f"""
-                update video_content_understanding  
-                set status = %s
-                where id in %s;
-            """
-            self.db.save(
-                query=update_sql,
-                params=(
-                    const.VIDEO_UNDERSTAND_INIT_STATUS,
-                    tuple(id_list)
-                )
+        roll_back_rows = self.db.save(
+            query=update_sql,
+            params=(
+                const.VIDEO_UNDERSTAND_INIT_STATUS,
+                const.VIDEO_LOCK,
+                int(time.time()) - const.MAX_PROCESSING_TIME
             )
+        )
+        return roll_back_rows
 
     def update_video_status(self, ori_status, new_status, pq_vid):
         """
@@ -206,7 +194,7 @@ class GenerateTextFromVideo(object):
         """
         self.google_ai_api.delete_video(file_name)
 
-    def get_tasks(self):
+    def get_task_list(self):
         """
         获取处理视频转文本任务
         """
@@ -224,8 +212,9 @@ class GenerateTextFromVideo(object):
         """
         处理视频转文本任务
         """
-        self.roll_back_processing_videos()
-        task_list = self.get_tasks()
+        self.roll_back_lock_tasks()
+
+        task_list = self.get_task_list()
         while task_list:
             for task in tqdm(task_list, desc="convert video to text"):
                 print(task['pq_vid'], task['file_name'])

+ 92 - 71
tasks/article_summary_task.py

@@ -1,10 +1,14 @@
 """
 @author: luojunhui
 """
+
+import time
+import traceback
+
 from pymysql.cursors import DictCursor
 from tqdm import tqdm
 
-from applications.api import get_response_by_deep_seek_api
+from applications.api import fetch_deepseek_response
 from applications.const import VideoToTextConst
 from applications.db import DatabaseConnector
 from config import long_articles_config
@@ -32,6 +36,7 @@ class ArticleSummaryTask(object):
     """
     文章总结任务
     """
+
     def __init__(self):
         self.db_client = None
 
@@ -47,100 +52,116 @@ class ArticleSummaryTask(object):
         获取任务列表
         """
         select_sql = f"""
-            select t1.video_text, t2.audit_video_id
-            from video_content_understanding t1 
-                join publish_single_video_source t2 
-                on t1.pq_vid = t2.audit_video_id
-            where t1.status = {const.VIDEO_UNDERSTAND_SUCCESS_STATUS} 
-                and t2.bad_status = {const.ARTICLE_GOOD_STATUS} 
-                and t2.extract_status = {const.EXTRACT_INIT_STATUS};
+            select id, video_text
+            from video_content_understanding
+            where summary_status = {const.SUMMARY_INIT_STATUS} and status = {const.VIDEO_UNDERSTAND_SUCCESS_STATUS}
+            limit {const.SUMMARY_BATCH_SIZE};
         """
         task_list = self.db_client.fetch(select_sql, cursor_type=DictCursor)
         return task_list
 
-    def process_each_task(self, task):
+    def rollback_lock_tasks(self):
         """
-        处理每个任务
+        rollback tasks which have been locked for a long time
         """
-        video_text = task["video_text"]
-        audit_video_id = task["audit_video_id"]
-        # 开始处理,将extract_status更新为101
+        now_timestamp = int(time.time())
+        timestamp_threshold = now_timestamp - const.MAX_PROCESSING_TIME
         update_sql = f"""
-            update publish_single_video_source 
-            set extract_status = %s 
-            where audit_video_id = %s and extract_status = %s;
+            update video_content_understanding
+            set summary_status = %s
+            where summary_status = %s and status_update_timestamp < %s;
         """
-        affected_rows = self.db_client.save(
+        rollback_rows = self.db_client.save(
             query=update_sql,
-            params=(const.EXTRACT_PROCESSING_STATUS, audit_video_id, const.EXTRACT_INIT_STATUS)
+            params=(const.SUMMARY_INIT_STATUS, const.SUMMARY_LOCK, timestamp_threshold),
+        )
+
+        return rollback_rows
+
+    def handle_task_execution(self, task):
+        """
+        :param task: keys: [id, video_text]
+        """
+        task_id = task["id"]
+        video_text = task["video_text"]
+
+        # Lock Task
+        affected_rows = self.update_task_status(
+            task_id, const.SUMMARY_INIT_STATUS, const.SUMMARY_LOCK
         )
         if not affected_rows:
             return
+
         try:
-            # 生成prompt
+            # generate prompt
             prompt = generate_prompt(video_text)
-            response = get_response_by_deep_seek_api(model="DeepSeek-R1", prompt=prompt)
-            if response:
-                update_sql = f"""
-                    update publish_single_video_source 
-                    set extract_status = %s, summary_text = %s
-                    where audit_video_id = %s and extract_status = %s;
-                """
-                affected_rows = self.db_client.save(
-                    query=update_sql,
-                    params=(
-                        const.EXTRACT_SUCCESS_STATUS,
-                        response.strip(),
-                        audit_video_id,
-                        const.EXTRACT_PROCESSING_STATUS
-                    )
-                )
-                print(affected_rows)
+
+            # get result from deep seek AI
+            result = fetch_deepseek_response(model="DeepSeek-R1", prompt=prompt)
+            if result:
+                # set as success and update summary text
+                self.set_summary_text_for_task(task_id, result.strip())
             else:
-                update_sql = f"""
-                    update publish_single_video_source 
-                    set extract_status = %s
-                    where audit_video_id = %s and extract_status = %s;
-                """
-                affected_rows = self.db_client.save(
-                    query=update_sql,
-                    params=(
-                        const.EXTRACT_FAIL_STATUS,
-                        audit_video_id,
-                        const.EXTRACT_PROCESSING_STATUS
-                    )
+                # set as fail
+                self.update_task_status(
+                    task_id, const.SUMMARY_LOCK, const.SUMMARY_FAIL_STATUS
                 )
-                print(affected_rows)
         except Exception as e:
             print(e)
+            print(traceback.format_exc())
             # set as fail
-            update_sql = f"""
-                update publish_single_video_source
-                set extract_status = %s
-                where audit_video_id = %s and extract_status = %s;
-            """
-            self.db_client.save(
-                query=update_sql,
-                params=(
-                    const.EXTRACT_FAIL_STATUS,
-                    audit_video_id,
-                    const.EXTRACT_PROCESSING_STATUS
-                )
+            self.update_task_status(
+                task_id, const.SUMMARY_LOCK, const.SUMMARY_FAIL_STATUS
             )
 
+    def set_summary_text_for_task(self, task_id, text):
+        """
+        successfully get summary text and update summary text to database
+        """
+        update_sql = f"""
+            update video_content_understanding
+            set summary_status = %s, summary_text = %s, status_update_timestamp = %s
+            where id = %s and summary_status = %s;
+        """
+        affected_rows = self.db_client.save(
+            query=update_sql,
+            params=(
+                const.SUMMARY_SUCCESS_STATUS,
+                text,
+                int(time.time()),
+                task_id,
+                const.SUMMARY_LOCK,
+            ),
+        )
+        return affected_rows
+
+    def update_task_status(self, task_id, ori_status, new_status):
+        """
+        修改任务状态
+        """
+        update_sql = f"""
+            update video_content_understanding
+            set summary_status = %s, status_update_timestamp = %s
+            where id = %s and summary_status = %s;
+        """
+        update_rows = self.db_client.save(
+            update_sql, (new_status, int(time.time()), task_id, ori_status)
+        )
+        return update_rows
+
     def deal(self):
         """
-        开始处理任务
+        entrance function for this class
         """
+        # first of all rollback tasks which have been locked for a long time
+        rollback_rows = self.rollback_lock_tasks()
+        print("rollback_lock_tasks: {}".format(rollback_rows))
+
+        # get task list
         task_list = self.get_task_list()
-        for task in tqdm(task_list):
+        for task in tqdm(task_list, desc="handle each task"):
             try:
-                self.process_each_task(task)
+                self.handle_task_execution(task=task)
             except Exception as e:
-                print(e)
-                continue
-
-
-
-
-
+                print("error: {}".format(e))
+                print(traceback.format_exc())