""" @author luojunhui @desc find best frame from each video """ import os import datetime from tqdm import tqdm from pymysql.cursors import DictCursor from applications.api import GoogleAIAPI from applications.db import DatabaseConnector from config import long_articles_config from coldStartTasks.ai_pipeline.basic import download_file from coldStartTasks.ai_pipeline.basic import update_task_queue_status from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks from coldStartTasks.ai_pipeline.basic import extract_prompt table_name = "long_articles_new_video_cover" POOL_SIZE = 1 google_ai = GoogleAIAPI() class ExtractVideoBestFrame: """ extract video best frame from each video by GeminiAI """ def __init__(self): self.db_client = DatabaseConnector(db_config=long_articles_config) self.db_client.connect() def get_upload_task_list(self, task_num: int = 10) -> list[dict]: """ get upload task list """ fetch_query = f""" select id, video_oss_path from {table_name} where upload_status = 0 limit {task_num}; """ upload_task_list = self.db_client.fetch( query=fetch_query, cursor_type=DictCursor ) return upload_task_list def get_extract_task_list(self, task_num: int = 10) -> list[dict]: """ get extract task list """ fetch_query = f""" select id, file_name from {table_name} where upload_status = 2 and extract_status = 0 order by file_expire_time limit {task_num}; """ extract_task_list = self.db_client.fetch( query=fetch_query, cursor_type=DictCursor ) return extract_task_list def get_processing_task_pool_size(self) -> int: """ get processing task pool size """ fetch_query = f""" select count(1) as pool_size from {table_name} where upload_status = 2 and file_state = 'PROCESSING'; """ fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor) processing_task_pool_size = ( fetch_response[0]["pool_size"] if fetch_response else 0 ) return processing_task_pool_size def set_upload_result( self, task_id: int, file_name: str, file_state: str, file_expire_time: str ) -> int: update_query = f""" update {table_name} set upload_status = %s, upload_status_ts = %s, file_name = %s, file_state = %s, file_expire_time = %s where id = %s and upload_status = %s; """ update_rows = self.db_client.save( query=update_query, params=( 2, datetime.datetime.now(), file_name, file_state, file_expire_time, task_id, 1, ), ) return update_rows def set_extract_result( self, task_id: int, file_state: str, best_frame_tims_ms: str ) -> int: update_query = f""" update {table_name} set extract_status = %s, extract_status_ts = %s, file_state = %s, best_frame_time_ms = %s where id = %s and extract_status = %s; """ update_rows = self.db_client.save( query=update_query, params=( 2, datetime.datetime.now(), file_state, best_frame_tims_ms, task_id, 1, ), ) return update_rows def upload_video_to_gemini_ai( self, max_processing_pool_size: int = POOL_SIZE ) -> None: # upload video to gemini ai roll_back_lock_tasks_count = roll_back_lock_tasks( db_client=self.db_client, task="upload", init_status=0, processing_status=1, max_process_time=3600, ) print("roll_back_lock_tasks_count", roll_back_lock_tasks_count) processing_task_num = self.get_processing_task_pool_size() res_task_num = max_processing_pool_size - processing_task_num if res_task_num: upload_task_list = self.get_upload_task_list(task_num=res_task_num) for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"): lock_status = update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="upload", ori_status=0, new_status=1, ) if not 
lock_status: continue try: file_path = download_file(task["id"], task["video_oss_path"]) upload_response = google_ai.upload_file(file_path) if upload_response: file_name, file_state, expire_time = upload_response self.set_upload_result( task_id=task["id"], file_name=file_name, file_state=file_state, file_expire_time=expire_time, ) else: # set status as fail update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="upload", ori_status=1, new_status=99, ) except Exception as e: print(f"download_file error: {e}") update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="upload", ori_status=1, new_status=99, ) continue else: print("Processing task pool is full") def extract_best_frame_with_gemini_ai(self): # roll back lock tasks roll_back_lock_tasks_count = roll_back_lock_tasks( db_client=self.db_client, task="extract", init_status=0, processing_status=1, max_process_time=3600, ) print("roll_back_lock_tasks_count", roll_back_lock_tasks_count) # do extract frame task task_list = self.get_extract_task_list() for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"): # lock task lock_status = update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="extract", ori_status=0, new_status=1, ) if not lock_status: continue file_name = task["file_name"] video_local_path = "static/{}.mp4".format(task["id"]) try: google_file = google_ai.get_google_file(file_name) state = google_file.state.name match state: case "PROCESSING": # google is still processing this video update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="extract", ori_status=1, new_status=0, ) print("this video is still processing") case "FAILED": # google process this video failed update_query = f""" update {table_name} set file_state = %s, extract_status = %s, extract_status_ts = %s where id = %s and extract_status = %s; """ update_rows = self.db_client.save( query=update_query, params=( "FAILED", 99, datetime.datetime.now(), task["id"], 1, ), ) case "ACTIVE": # video process successfully try: best_frame_tims_ms = google_ai.fetch_info_from_google_ai( prompt=extract_prompt, video_file=google_file ) if best_frame_tims_ms: self.set_extract_result( task_id=task["id"], file_state="ACTIVE", best_frame_tims_ms=best_frame_tims_ms, ) else: update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="extract", ori_status=1, new_status=99, ) # delete local file and google file if os.path.exists(video_local_path): os.remove(video_local_path) google_ai.delete_video(file_name) except Exception as e: print(e) update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="extract", ori_status=1, new_status=99, ) except Exception as e: print(f"update_task_queue_status error: {e}") update_task_queue_status( db_client=self.db_client, task_id=task["id"], task="extract", ori_status=1, new_status=99, )
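

# A minimal usage sketch: run the two stages in order, first pushing pending
# videos up to Gemini, then polling for finished files and extracting the
# best-frame timestamp. This entry point is an assumption for illustration;
# the project's own scheduler may invoke these methods differently.
if __name__ == "__main__":
    extract_task = ExtractVideoBestFrame()
    extract_task.upload_video_to_gemini_ai()
    extract_task.extract_best_frame_with_gemini_ai()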