# extract_video_best_frame.py
"""
@author luojunhui
@desc Find the best frame of each video via Gemini AI.
"""
import os
import datetime

from pymysql.cursors import DictCursor
from tqdm import tqdm

from applications.api import GoogleAIAPI
from applications.db import DatabaseConnector
from coldStartTasks.ai_pipeline.basic import (
    download_file,
    extract_prompt,
    roll_back_lock_tasks,
    update_task_queue_status,
)
from config import long_articles_config
# DB table holding the video-cover task queue (upload + extract state).
table_name = "long_articles_new_video_cover"
# Max number of videos allowed to sit in Gemini's PROCESSING state at once.
POOL_SIZE = 1
# Shared Gemini client reused by every task in this module.
google_ai = GoogleAIAPI()
  19. class ExtractVideoBestFrame:
  20. """
  21. extract video best frame from each video by GeminiAI
  22. """
  23. def __init__(self):
  24. self.db_client = DatabaseConnector(db_config=long_articles_config)
  25. self.db_client.connect()
  26. def get_upload_task_list(self, task_num: int = 10) -> list[dict]:
  27. """
  28. get upload task list
  29. """
  30. fetch_query = f"""
  31. select id, video_oss_path from {table_name}
  32. where upload_status = 0
  33. limit {task_num};
  34. """
  35. upload_task_list = self.db_client.fetch(
  36. query=fetch_query, cursor_type=DictCursor
  37. )
  38. return upload_task_list
  39. def get_extract_task_list(self, task_num: int = 10) -> list[dict]:
  40. """
  41. get extract task list
  42. """
  43. fetch_query = f"""
  44. select id, file_name from {table_name}
  45. where upload_status = 2 and extract_status = 0
  46. order by file_expire_time
  47. limit {task_num};
  48. """
  49. extract_task_list = self.db_client.fetch(
  50. query=fetch_query, cursor_type=DictCursor
  51. )
  52. return extract_task_list
  53. def get_processing_task_pool_size(self) -> int:
  54. """
  55. get processing task pool size
  56. """
  57. fetch_query = f"""
  58. select count(1) as pool_size from {table_name}
  59. where upload_status = 2 and file_state = 'PROCESSING';
  60. """
  61. fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
  62. processing_task_pool_size = (
  63. fetch_response[0]["pool_size"] if fetch_response else 0
  64. )
  65. return processing_task_pool_size
  66. def set_upload_result(
  67. self, task_id: int, file_name: str, file_state: str, file_expire_time: str
  68. ) -> int:
  69. update_query = f"""
  70. update {table_name}
  71. set upload_status = %s, upload_status_ts = %s,
  72. file_name = %s, file_state = %s, file_expire_time = %s
  73. where id = %s and upload_status = %s;
  74. """
  75. update_rows = self.db_client.save(
  76. query=update_query,
  77. params=(
  78. 2,
  79. datetime.datetime.now(),
  80. file_name,
  81. file_state,
  82. file_expire_time,
  83. task_id,
  84. 1,
  85. ),
  86. )
  87. return update_rows
  88. def set_extract_result(
  89. self, task_id: int, file_state: str, best_frame_tims_ms: str
  90. ) -> int:
  91. update_query = f"""
  92. update {table_name}
  93. set extract_status = %s, extract_status_ts = %s,
  94. file_state = %s, best_frame_time_ms = %s
  95. where id = %s and extract_status = %s;
  96. """
  97. update_rows = self.db_client.save(
  98. query=update_query,
  99. params=(
  100. 2,
  101. datetime.datetime.now(),
  102. file_state,
  103. best_frame_tims_ms,
  104. task_id,
  105. 1,
  106. ),
  107. )
  108. return update_rows
  109. def upload_video_to_gemini_ai(
  110. self, max_processing_pool_size: int = POOL_SIZE
  111. ) -> None:
  112. # upload video to gemini ai
  113. roll_back_lock_tasks_count = roll_back_lock_tasks(
  114. db_client=self.db_client,
  115. task="upload",
  116. init_status=0,
  117. processing_status=1,
  118. max_process_time=3600,
  119. )
  120. print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
  121. processing_task_num = self.get_processing_task_pool_size()
  122. res_task_num = max_processing_pool_size - processing_task_num
  123. if res_task_num:
  124. upload_task_list = self.get_upload_task_list(task_num=res_task_num)
  125. for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
  126. lock_status = update_task_queue_status(
  127. db_client=self.db_client,
  128. task_id=task["id"],
  129. task="upload",
  130. ori_status=0,
  131. new_status=1,
  132. )
  133. if not lock_status:
  134. continue
  135. try:
  136. file_path = download_file(task["id"], task["video_oss_path"])
  137. upload_response = google_ai.upload_file(file_path)
  138. if upload_response:
  139. file_name, file_state, expire_time = upload_response
  140. self.set_upload_result(
  141. task_id=task["id"],
  142. file_name=file_name,
  143. file_state=file_state,
  144. file_expire_time=expire_time,
  145. )
  146. else:
  147. # set status as fail
  148. update_task_queue_status(
  149. db_client=self.db_client,
  150. task_id=task["id"],
  151. task="upload",
  152. ori_status=1,
  153. new_status=99,
  154. )
  155. except Exception as e:
  156. print(f"download_file error: {e}")
  157. update_task_queue_status(
  158. db_client=self.db_client,
  159. task_id=task["id"],
  160. task="upload",
  161. ori_status=1,
  162. new_status=99,
  163. )
  164. continue
  165. else:
  166. print("Processing task pool is full")
  167. def extract_best_frame_with_gemini_ai(self):
  168. # roll back lock tasks
  169. roll_back_lock_tasks_count = roll_back_lock_tasks(
  170. db_client=self.db_client,
  171. task="extract",
  172. init_status=0,
  173. processing_status=1,
  174. max_process_time=3600,
  175. )
  176. print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
  177. # do extract frame task
  178. task_list = self.get_extract_task_list()
  179. for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
  180. # lock task
  181. lock_status = update_task_queue_status(
  182. db_client=self.db_client,
  183. task_id=task["id"],
  184. task="extract",
  185. ori_status=0,
  186. new_status=1,
  187. )
  188. if not lock_status:
  189. continue
  190. file_name = task["file_name"]
  191. video_local_path = "static/{}.mp4".format(task["id"])
  192. try:
  193. google_file = google_ai.get_google_file(file_name)
  194. state = google_file.state.name
  195. match state:
  196. case "PROCESSING":
  197. # google is still processing this video
  198. update_task_queue_status(
  199. db_client=self.db_client,
  200. task_id=task["id"],
  201. task="extract",
  202. ori_status=1,
  203. new_status=0,
  204. )
  205. print("this video is still processing")
  206. case "FAILED":
  207. # google process this video failed
  208. update_query = f"""
  209. update {table_name}
  210. set file_state = %s, extract_status = %s, extract_status_ts = %s
  211. where id = %s and extract_status = %s;
  212. """
  213. update_rows = self.db_client.save(
  214. query=update_query,
  215. params=(
  216. "FAILED",
  217. 99,
  218. datetime.datetime.now(),
  219. task["id"],
  220. 1,
  221. ),
  222. )
  223. case "ACTIVE":
  224. # video process successfully
  225. try:
  226. best_frame_tims_ms = google_ai.fetch_info_from_google_ai(
  227. prompt=extract_prompt, video_file=google_file
  228. )
  229. if best_frame_tims_ms:
  230. self.set_extract_result(
  231. task_id=task["id"],
  232. file_state="ACTIVE",
  233. best_frame_tims_ms=best_frame_tims_ms,
  234. )
  235. else:
  236. update_task_queue_status(
  237. db_client=self.db_client,
  238. task_id=task["id"],
  239. task="extract",
  240. ori_status=1,
  241. new_status=99,
  242. )
  243. # delete local file and google file
  244. if os.path.exists(video_local_path):
  245. os.remove(video_local_path)
  246. google_ai.delete_video(file_name)
  247. except Exception as e:
  248. print(e)
  249. update_task_queue_status(
  250. db_client=self.db_client,
  251. task_id=task["id"],
  252. task="extract",
  253. ori_status=1,
  254. new_status=99,
  255. )
  256. except Exception as e:
  257. print(f"update_task_queue_status error: {e}")
  258. update_task_queue_status(
  259. db_client=self.db_client,
  260. task_id=task["id"],
  261. task="extract",
  262. ori_status=1,
  263. new_status=99,
  264. )