luojunhui 3 주 전
부모
커밋
593636b024
3개의 변경된 파일 · 304개의 추가와 38개의 삭제
  1. 1 1
      applications/api/google_ai_api.py
  2. 17 37
      coldStartTasks/ai_pipeline/basic.py
  3. 286 0
      coldStartTasks/ai_pipeline/extract_video_best_frame.py

+ 1 - 1
applications/api/google_ai_api.py

@@ -55,7 +55,7 @@ class GoogleAIAPI(object):
             print(e)
             return None
 
-    def get_video_text(self, prompt, video_file):
+    def fetch_info_from_google_ai(self, prompt, video_file):
         """
         获取视频文本
         prompt: 提示词

+ 17 - 37
coldStartTasks/ai_pipeline/basic.py

@@ -6,28 +6,25 @@ import requests
 from applications.db import DatabaseConnector
 
 
def get_status_field_by_task(task: str) -> tuple[str, str]:
    """
    Map a task name to its (status column, status-timestamp column) pair
    in the `long_articles_new_video_cover` task-queue table.

    Raises ValueError for any task name outside the known set.
    """
    field_pairs = {
        "upload": ("upload_status", "upload_status_ts"),
        "extract": ("extract_status", "extract_status_ts"),
        "get_cover": ("get_cover_status", "get_cover_status_ts"),
    }
    if task not in field_pairs:
        raise ValueError(f"Unexpected task: {task}")
    return field_pairs[task]
 
 
 def roll_back_lock_tasks(
     db_client: DatabaseConnector,
-    process: str,
+    task: str,
     max_process_time: int,
     init_status: int,
     processing_status: int
@@ -35,11 +32,11 @@ def roll_back_lock_tasks(
     """
     rollback tasks which have been locked for a long time
     """
-    status, update_timestamp = get_status_field_by_process(process)
+    status, update_timestamp = get_status_field_by_task(task)
     now_timestamp = int(time.time())
     timestamp_threshold = now_timestamp - max_process_time
     update_query = f"""
-        update video_content_understanding
+        update long_articles_new_video_cover
         set {status} = %s
         where {status} = %s and {update_timestamp} < %s;
     """
@@ -83,17 +80,17 @@ def generate_summary_prompt(text):
 def update_task_queue_status(
         db_client: DatabaseConnector,
         task_id: int,
-        process: str,
+        task: str,
         ori_status: int,
         new_status: int) -> int:
     # update task queue status
-    status, update_timestamp = get_status_field_by_process(process)
+    status, update_timestamp = get_status_field_by_task(task)
     update_query = f"""
-        update video_content_understanding 
+        update long_articles_new_video_cover
         set {status} = %s, {update_timestamp} = %s
         where {status} = %s and id = %s;
     """
-    roll_back_rows = db_client.save(
+    update_rows = db_client.save(
         query=update_query,
         params=(
             new_status,
@@ -102,21 +99,4 @@ def update_task_queue_status(
             task_id,
         ),
     )
-    return roll_back_rows
-
-
-def update_video_pool_status(
-        db_client: DatabaseConnector,
-        content_trace_id: str,
-        ori_status: int,
-        new_status: int) -> int:
-    # update publish_single_source_status
-    update_query = f"""
-                update publish_single_video_source
-                set status = %s
-                where content_trace_id = %s and status = %s
-            """
-    affected_rows = db_client.save(
-        query=update_query, params=(new_status, content_trace_id, ori_status)
-    )
-    return affected_rows
+    return update_rows

+ 286 - 0
coldStartTasks/ai_pipeline/extract_video_best_frame.py

@@ -0,0 +1,286 @@
+"""
+@author luojunhui
+@desc find best frame from each video
+"""
+
+import os
+import datetime
+from tqdm import tqdm
+from pymysql.cursors import DictCursor
+
+from applications.api import GoogleAIAPI
+from applications.db import DatabaseConnector
+from config import long_articles_config
+from coldStartTasks.ai_pipeline.basic import download_file
+from coldStartTasks.ai_pipeline.basic import update_task_queue_status
+from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
+
+table_name = "long_articles_new_video_cover"
+POOL_SIZE = 1
+google_ai = GoogleAIAPI()
+
+
+class ExtractVideoBestFrame:
+    """
+    extract video best frame from each video by GeminiAI
+    """
+
+    def __init__(self):
+        self.db_client = DatabaseConnector(db_config=long_articles_config)
+        self.db_client.connect()
+
+    def get_upload_task_list(self, task_num: int = 10) -> list[dict]:
+        """
+        get upload task list
+        """
+        fetch_query = f"""
+            select id, video_oss_path from {table_name} 
+            where upload_status = 0
+            limit {task_num};
+        """
+        upload_task_list = self.db_client.fetch(
+            query=fetch_query, cursor_type=DictCursor
+        )
+        return upload_task_list
+
+    def get_extract_task_list(self, task_num: int = 10) -> list[dict]:
+        """
+        get extract task list
+        """
+        fetch_query = f"""
+            select id, file_name from {table_name}
+            where upload_status = 2 and extract_status = 0
+            order by file_expire_time
+            limit {task_num};
+        """
+        extract_task_list = self.db_client.fetch(
+            query=fetch_query, cursor_type=DictCursor
+        )
+        return extract_task_list
+
+    def get_processing_task_pool_size(self) -> int:
+        """
+        get processing task pool size
+        """
+        fetch_query = f"""
+            select count(1) as pool_size from {table_name}
+            where upload_status = 2 and file_state = 'PROCESSING';
+        """
+        fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
+        processing_task_pool_size = (
+            fetch_response[0]["pool_size"] if fetch_response else 0
+        )
+        return processing_task_pool_size
+
+    def set_upload_result(
+        self, task_id: int, file_name: str, file_state: str, file_expire_time: str
+    ) -> int:
+        update_query = f"""
+            update {table_name} 
+            set upload_status = %s, upload_status_ts = %s,
+                file_name = %s, file_state = %s, file_expire_time = %s
+            where id = %s and upload_status = %s;
+        """
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                2,
+                datetime.datetime.now(),
+                file_name,
+                file_state,
+                file_expire_time,
+                task_id,
+                1,
+            ),
+        )
+        return update_rows
+
+    def set_extract_result(
+        self, task_id: int, file_state: str, best_frame_tims_ms: str
+    ) -> int:
+        update_query = f"""
+            update {table_name} 
+            set extract_status = %s, extract_status_ts = %s,
+                file_state = %s, best_frame_time_ms = %s
+            where id = %s and extract_status = %s;
+        """
+        update_rows = self.db_client.save(
+            query=update_query,
+            params=(
+                2,
+                datetime.datetime.now(),
+                file_state,
+                best_frame_tims_ms,
+                task_id,
+                1,
+            ),
+        )
+        return update_rows
+
+    def upload_video_to_gemini_ai(
+        self, max_processing_pool_size: int = POOL_SIZE
+    ) -> None:
+        # upload video to gemini ai
+        roll_back_lock_tasks_count = roll_back_lock_tasks(
+            db_client=self.db_client,
+            task="upload",
+            init_status=0,
+            processing_status=1,
+            max_process_time=3600,
+        )
+        print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
+
+        processing_task_num = self.get_processing_task_pool_size()
+        res_task_num = max_processing_pool_size - processing_task_num
+        if res_task_num:
+            upload_task_list = self.get_upload_task_list(task_num=res_task_num)
+            for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
+                lock_status = update_task_queue_status(
+                    db_client=self.db_client,
+                    task_id=task["id"],
+                    task="upload",
+                    ori_status=0,
+                    new_status=1,
+                )
+                if not lock_status:
+                    continue
+
+                try:
+                    file_path = download_file(task["id"], task["video_oss_path"])
+                    upload_response = google_ai.upload_file(file_path)
+                    if upload_response:
+                        file_name, file_state, expire_time = upload_response
+                        self.set_upload_result(
+                            task_id=task["id"],
+                            file_name=file_name,
+                            file_state=file_state,
+                            file_expire_time=expire_time,
+                        )
+                    else:
+                        # set status as fail
+                        update_task_queue_status(
+                            db_client=self.db_client,
+                            task_id=task["id"],
+                            task="upload",
+                            ori_status=1,
+                            new_status=99,
+                        )
+                except Exception as e:
+                    print(f"download_file error: {e}")
+                    update_task_queue_status(
+                        db_client=self.db_client,
+                        task_id=task["id"],
+                        task="upload",
+                        ori_status=1,
+                        new_status=99,
+                    )
+                    continue
+
+        else:
+            print("Processing task pool is full")
+
+    def extract_best_frame_with_gemini_ai(self):
+        # roll back lock tasks
+        roll_back_lock_tasks_count = roll_back_lock_tasks(
+            db_client=self.db_client,
+            task="extract",
+            init_status=0,
+            processing_status=1,
+            max_process_time=3600,
+        )
+        print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
+
+        # do extract frame task
+        task_list = self.get_extract_task_list()
+        for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
+            # lock task
+            lock_status = update_task_queue_status(
+                db_client=self.db_client,
+                task_id=task["id"],
+                task="extract",
+                ori_status=0,
+                new_status=1,
+            )
+            if not lock_status:
+                continue
+
+            file_name = task["file_name"]
+            video_local_path = "static/{}.mp4".format(task["id"])
+            try:
+                google_file = google_ai.get_google_file(file_name)
+                state = google_file.state.name
+
+                match state:
+                    case "PROCESSING":
+                        # google is still processing this video
+                        update_task_queue_status(
+                            db_client=self.db_client,
+                            task_id=task["id"],
+                            task="extract",
+                            ori_status=1,
+                            new_status=0,
+                        )
+                        print("this video is still processing")
+
+                    case "FAILED":
+                        # google process this video failed
+                        update_query = f"""
+                            update {table_name}
+                            set file_state = %s, extract_status = %s, extract_status_ts = %s
+                            where id = %s and extract_status = %s;
+                        """
+                        update_rows = self.db_client.save(
+                            query=update_query,
+                            params=(
+                                "FAILED",
+                                99,
+                                datetime.datetime.now(),
+                                task["id"],
+                                1,
+                            ),
+                        )
+
+                    case "ACTIVE":
+                        # video process successfully
+                        try:
+                            best_frame_tims_ms = google_ai.fetch_info_from_google_ai(
+                                prompt="", video_file=google_file
+                            )
+                            if best_frame_tims_ms:
+                                self.set_extract_result(
+                                    task_id=task["id"],
+                                    file_state="ACTIVE",
+                                    best_frame_tims_ms=best_frame_tims_ms,
+                                )
+                            else:
+                                update_task_queue_status(
+                                    db_client=self.db_client,
+                                    task_id=task["id"],
+                                    task="extract",
+                                    ori_status=1,
+                                    new_status=99,
+                                )
+                            # delete local file and google file
+                            if os.path.exists(video_local_path):
+                                os.remove(video_local_path)
+
+                            google_ai.delete_video(file_name)
+                        except Exception as e:
+                            print(e)
+                            update_task_queue_status(
+                                db_client=self.db_client,
+                                task_id=task["id"],
+                                task="extract",
+                                ori_status=1,
+                                new_status=99,
+                            )
+
+            except Exception as e:
+                print(f"update_task_queue_status error: {e}")
+                update_task_queue_status(
+                    db_client=self.db_client,
+                    task_id=task["id"],
+                    task="extract",
+                    ori_status=1,
+                    new_status=99,
+                )