|
@@ -0,0 +1,286 @@
|
|
|
+"""
|
|
|
+@author luojunhui
|
|
|
+@desc find best frame from each video
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import datetime
|
|
|
+from tqdm import tqdm
|
|
|
+from pymysql.cursors import DictCursor
|
|
|
+
|
|
|
+from applications.api import GoogleAIAPI
|
|
|
+from applications.db import DatabaseConnector
|
|
|
+from config import long_articles_config
|
|
|
+from coldStartTasks.ai_pipeline.basic import download_file
|
|
|
+from coldStartTasks.ai_pipeline.basic import update_task_queue_status
|
|
|
+from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
|
|
|
+
|
|
|
# Task-queue table driving the whole pipeline; each row is one video.
table_name = "long_articles_new_video_cover"
# Max number of videos allowed to sit in Google's "PROCESSING" state at once.
POOL_SIZE = 1
# Shared Gemini client used by every task in this module.
google_ai = GoogleAIAPI()
|
|
|
+
|
|
|
+
|
|
|
+class ExtractVideoBestFrame:
|
|
|
+ """
|
|
|
+ extract video best frame from each video by GeminiAI
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self.db_client = DatabaseConnector(db_config=long_articles_config)
|
|
|
+ self.db_client.connect()
|
|
|
+
|
|
|
+ def get_upload_task_list(self, task_num: int = 10) -> list[dict]:
|
|
|
+ """
|
|
|
+ get upload task list
|
|
|
+ """
|
|
|
+ fetch_query = f"""
|
|
|
+ select id, video_oss_path from {table_name}
|
|
|
+ where upload_status = 0
|
|
|
+ limit {task_num};
|
|
|
+ """
|
|
|
+ upload_task_list = self.db_client.fetch(
|
|
|
+ query=fetch_query, cursor_type=DictCursor
|
|
|
+ )
|
|
|
+ return upload_task_list
|
|
|
+
|
|
|
+ def get_extract_task_list(self, task_num: int = 10) -> list[dict]:
|
|
|
+ """
|
|
|
+ get extract task list
|
|
|
+ """
|
|
|
+ fetch_query = f"""
|
|
|
+ select id, file_name from {table_name}
|
|
|
+ where upload_status = 2 and extract_status = 0
|
|
|
+ order by file_expire_time
|
|
|
+ limit {task_num};
|
|
|
+ """
|
|
|
+ extract_task_list = self.db_client.fetch(
|
|
|
+ query=fetch_query, cursor_type=DictCursor
|
|
|
+ )
|
|
|
+ return extract_task_list
|
|
|
+
|
|
|
+ def get_processing_task_pool_size(self) -> int:
|
|
|
+ """
|
|
|
+ get processing task pool size
|
|
|
+ """
|
|
|
+ fetch_query = f"""
|
|
|
+ select count(1) as pool_size from {table_name}
|
|
|
+ where upload_status = 2 and file_state = 'PROCESSING';
|
|
|
+ """
|
|
|
+ fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
|
|
|
+ processing_task_pool_size = (
|
|
|
+ fetch_response[0]["pool_size"] if fetch_response else 0
|
|
|
+ )
|
|
|
+ return processing_task_pool_size
|
|
|
+
|
|
|
+ def set_upload_result(
|
|
|
+ self, task_id: int, file_name: str, file_state: str, file_expire_time: str
|
|
|
+ ) -> int:
|
|
|
+ update_query = f"""
|
|
|
+ update {table_name}
|
|
|
+ set upload_status = %s, upload_status_ts = %s,
|
|
|
+ file_name = %s, file_state = %s, file_expire_time = %s
|
|
|
+ where id = %s and upload_status = %s;
|
|
|
+ """
|
|
|
+ update_rows = self.db_client.save(
|
|
|
+ query=update_query,
|
|
|
+ params=(
|
|
|
+ 2,
|
|
|
+ datetime.datetime.now(),
|
|
|
+ file_name,
|
|
|
+ file_state,
|
|
|
+ file_expire_time,
|
|
|
+ task_id,
|
|
|
+ 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ return update_rows
|
|
|
+
|
|
|
+ def set_extract_result(
|
|
|
+ self, task_id: int, file_state: str, best_frame_tims_ms: str
|
|
|
+ ) -> int:
|
|
|
+ update_query = f"""
|
|
|
+ update {table_name}
|
|
|
+ set extract_status = %s, extract_status_ts = %s,
|
|
|
+ file_state = %s, best_frame_time_ms = %s
|
|
|
+ where id = %s and extract_status = %s;
|
|
|
+ """
|
|
|
+ update_rows = self.db_client.save(
|
|
|
+ query=update_query,
|
|
|
+ params=(
|
|
|
+ 2,
|
|
|
+ datetime.datetime.now(),
|
|
|
+ file_state,
|
|
|
+ best_frame_tims_ms,
|
|
|
+ task_id,
|
|
|
+ 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ return update_rows
|
|
|
+
|
|
|
+ def upload_video_to_gemini_ai(
|
|
|
+ self, max_processing_pool_size: int = POOL_SIZE
|
|
|
+ ) -> None:
|
|
|
+ # upload video to gemini ai
|
|
|
+ roll_back_lock_tasks_count = roll_back_lock_tasks(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task="upload",
|
|
|
+ init_status=0,
|
|
|
+ processing_status=1,
|
|
|
+ max_process_time=3600,
|
|
|
+ )
|
|
|
+ print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
|
|
|
+
|
|
|
+ processing_task_num = self.get_processing_task_pool_size()
|
|
|
+ res_task_num = max_processing_pool_size - processing_task_num
|
|
|
+ if res_task_num:
|
|
|
+ upload_task_list = self.get_upload_task_list(task_num=res_task_num)
|
|
|
+ for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
|
|
|
+ lock_status = update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="upload",
|
|
|
+ ori_status=0,
|
|
|
+ new_status=1,
|
|
|
+ )
|
|
|
+ if not lock_status:
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ file_path = download_file(task["id"], task["video_oss_path"])
|
|
|
+ upload_response = google_ai.upload_file(file_path)
|
|
|
+ if upload_response:
|
|
|
+ file_name, file_state, expire_time = upload_response
|
|
|
+ self.set_upload_result(
|
|
|
+ task_id=task["id"],
|
|
|
+ file_name=file_name,
|
|
|
+ file_state=file_state,
|
|
|
+ file_expire_time=expire_time,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # set status as fail
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="upload",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=99,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(f"download_file error: {e}")
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="upload",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=99,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ else:
|
|
|
+ print("Processing task pool is full")
|
|
|
+
|
|
|
+ def extract_best_frame_with_gemini_ai(self):
|
|
|
+ # roll back lock tasks
|
|
|
+ roll_back_lock_tasks_count = roll_back_lock_tasks(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task="extract",
|
|
|
+ init_status=0,
|
|
|
+ processing_status=1,
|
|
|
+ max_process_time=3600,
|
|
|
+ )
|
|
|
+ print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
|
|
|
+
|
|
|
+ # do extract frame task
|
|
|
+ task_list = self.get_extract_task_list()
|
|
|
+ for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
|
|
|
+ # lock task
|
|
|
+ lock_status = update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="extract",
|
|
|
+ ori_status=0,
|
|
|
+ new_status=1,
|
|
|
+ )
|
|
|
+ if not lock_status:
|
|
|
+ continue
|
|
|
+
|
|
|
+ file_name = task["file_name"]
|
|
|
+ video_local_path = "static/{}.mp4".format(task["id"])
|
|
|
+ try:
|
|
|
+ google_file = google_ai.get_google_file(file_name)
|
|
|
+ state = google_file.state.name
|
|
|
+
|
|
|
+ match state:
|
|
|
+ case "PROCESSING":
|
|
|
+ # google is still processing this video
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="extract",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=0,
|
|
|
+ )
|
|
|
+ print("this video is still processing")
|
|
|
+
|
|
|
+ case "FAILED":
|
|
|
+ # google process this video failed
|
|
|
+ update_query = f"""
|
|
|
+ update {table_name}
|
|
|
+ set file_state = %s, extract_status = %s, extract_status_ts = %s
|
|
|
+ where id = %s and extract_status = %s;
|
|
|
+ """
|
|
|
+ update_rows = self.db_client.save(
|
|
|
+ query=update_query,
|
|
|
+ params=(
|
|
|
+ "FAILED",
|
|
|
+ 99,
|
|
|
+ datetime.datetime.now(),
|
|
|
+ task["id"],
|
|
|
+ 1,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ case "ACTIVE":
|
|
|
+ # video process successfully
|
|
|
+ try:
|
|
|
+ best_frame_tims_ms = google_ai.fetch_info_from_google_ai(
|
|
|
+ prompt="", video_file=google_file
|
|
|
+ )
|
|
|
+ if best_frame_tims_ms:
|
|
|
+ self.set_extract_result(
|
|
|
+ task_id=task["id"],
|
|
|
+ file_state="ACTIVE",
|
|
|
+ best_frame_tims_ms=best_frame_tims_ms,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="extract",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=99,
|
|
|
+ )
|
|
|
+ # delete local file and google file
|
|
|
+ if os.path.exists(video_local_path):
|
|
|
+ os.remove(video_local_path)
|
|
|
+
|
|
|
+ google_ai.delete_video(file_name)
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="extract",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=99,
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"update_task_queue_status error: {e}")
|
|
|
+ update_task_queue_status(
|
|
|
+ db_client=self.db_client,
|
|
|
+ task_id=task["id"],
|
|
|
+ task="extract",
|
|
|
+ ori_status=1,
|
|
|
+ new_status=99,
|
|
|
+ )
|