
  1. """
  2. @author luojunhui
  3. @desc find best frame from each video
  4. """
  5. import os
  6. import datetime
  7. from tqdm import tqdm
  8. from pymysql.cursors import DictCursor
  9. from applications.api import GoogleAIAPI
  10. from applications.db import DatabaseConnector
  11. from config import long_articles_config
  12. from coldStartTasks.ai_pipeline.basic import download_file
  13. from coldStartTasks.ai_pipeline.basic import update_task_queue_status
  14. from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
  15. table_name = "long_articles_new_video_cover"
  16. POOL_SIZE = 1
  17. google_ai = GoogleAIAPI()
class ExtractVideoBestFrame:
    """
    extract video best frame from each video by GeminiAI
    """

    def __init__(self):
        self.db_client = DatabaseConnector(db_config=long_articles_config)
        self.db_client.connect()

    def get_upload_task_list(self, task_num: int = 10) -> list[dict]:
        """
        get upload task list
        """
        fetch_query = f"""
            select id, video_oss_path from {table_name}
            where upload_status = 0
            limit {task_num};
        """
        upload_task_list = self.db_client.fetch(
            query=fetch_query, cursor_type=DictCursor
        )
        return upload_task_list

    def get_extract_task_list(self, task_num: int = 10) -> list[dict]:
        """
        get extract task list
        """
        fetch_query = f"""
            select id, file_name from {table_name}
            where upload_status = 2 and extract_status = 0
            order by file_expire_time
            limit {task_num};
        """
        extract_task_list = self.db_client.fetch(
            query=fetch_query, cursor_type=DictCursor
        )
        return extract_task_list

    def get_processing_task_pool_size(self) -> int:
        """
        get processing task pool size
        """
        fetch_query = f"""
            select count(1) as pool_size from {table_name}
            where upload_status = 2 and file_state = 'PROCESSING';
        """
        fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
        processing_task_pool_size = (
            fetch_response[0]["pool_size"] if fetch_response else 0
        )
        return processing_task_pool_size

    def set_upload_result(
        self, task_id: int, file_name: str, file_state: str, file_expire_time: str
    ) -> int:
        """
        mark one task as uploaded (upload_status 1 -> 2) and record the Gemini file info
        """
        update_query = f"""
            update {table_name}
            set upload_status = %s, upload_status_ts = %s,
                file_name = %s, file_state = %s, file_expire_time = %s
            where id = %s and upload_status = %s;
        """
        update_rows = self.db_client.save(
            query=update_query,
            params=(
                2,
                datetime.datetime.now(),
                file_name,
                file_state,
                file_expire_time,
                task_id,
                1,
            ),
        )
        return update_rows

    def set_extract_result(
        self, task_id: int, file_state: str, best_frame_time_ms: str
    ) -> int:
        """
        mark one task as extracted (extract_status 1 -> 2) and record the best frame timestamp
        """
        update_query = f"""
            update {table_name}
            set extract_status = %s, extract_status_ts = %s,
                file_state = %s, best_frame_time_ms = %s
            where id = %s and extract_status = %s;
        """
        update_rows = self.db_client.save(
            query=update_query,
            params=(
                2,
                datetime.datetime.now(),
                file_state,
                best_frame_time_ms,
                task_id,
                1,
            ),
        )
        return update_rows

    def upload_video_to_gemini_ai(
        self, max_processing_pool_size: int = POOL_SIZE
    ) -> None:
        """
        download pending videos from OSS and upload them to Gemini, bounded by the processing pool size
        """
        # release tasks that have been locked for more than an hour
        roll_back_lock_tasks_count = roll_back_lock_tasks(
            db_client=self.db_client,
            task="upload",
            init_status=0,
            processing_status=1,
            max_process_time=3600,
        )
        print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
        processing_task_num = self.get_processing_task_pool_size()
        res_task_num = max_processing_pool_size - processing_task_num
        if res_task_num > 0:
            upload_task_list = self.get_upload_task_list(task_num=res_task_num)
            for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
                # lock task
                lock_status = update_task_queue_status(
                    db_client=self.db_client,
                    task_id=task["id"],
                    task="upload",
                    ori_status=0,
                    new_status=1,
                )
                if not lock_status:
                    continue
                try:
                    file_path = download_file(task["id"], task["video_oss_path"])
                    upload_response = google_ai.upload_file(file_path)
                    if upload_response:
                        file_name, file_state, expire_time = upload_response
                        self.set_upload_result(
                            task_id=task["id"],
                            file_name=file_name,
                            file_state=file_state,
                            file_expire_time=expire_time,
                        )
                    else:
                        # set status as fail
                        update_task_queue_status(
                            db_client=self.db_client,
                            task_id=task["id"],
                            task="upload",
                            ori_status=1,
                            new_status=99,
                        )
                except Exception as e:
                    print(f"upload task error: {e}")
                    update_task_queue_status(
                        db_client=self.db_client,
                        task_id=task["id"],
                        task="upload",
                        ori_status=1,
                        new_status=99,
                    )
                    continue
        else:
            print("Processing task pool is full")

    def extract_best_frame_with_gemini_ai(self):
        """
        poll uploaded files on Gemini and persist the best frame timestamp once processing is done
        """
        # release tasks that have been locked for more than an hour
        roll_back_lock_tasks_count = roll_back_lock_tasks(
            db_client=self.db_client,
            task="extract",
            init_status=0,
            processing_status=1,
            max_process_time=3600,
        )
        print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
        # do extract frame task
        task_list = self.get_extract_task_list()
        for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
            # lock task
            lock_status = update_task_queue_status(
                db_client=self.db_client,
                task_id=task["id"],
                task="extract",
                ori_status=0,
                new_status=1,
            )
            if not lock_status:
                continue
            file_name = task["file_name"]
            video_local_path = "static/{}.mp4".format(task["id"])
            try:
                google_file = google_ai.get_google_file(file_name)
                state = google_file.state.name
                match state:
                    case "PROCESSING":
                        # google is still processing this video, unlock the task and retry later
                        update_task_queue_status(
                            db_client=self.db_client,
                            task_id=task["id"],
                            task="extract",
                            ori_status=1,
                            new_status=0,
                        )
                        print("this video is still processing")
                    case "FAILED":
                        # google failed to process this video, mark the task as failed
                        update_query = f"""
                            update {table_name}
                            set file_state = %s, extract_status = %s, extract_status_ts = %s
                            where id = %s and extract_status = %s;
                        """
                        update_rows = self.db_client.save(
                            query=update_query,
                            params=(
                                "FAILED",
                                99,
                                datetime.datetime.now(),
                                task["id"],
                                1,
                            ),
                        )
                    case "ACTIVE":
                        # video processed successfully, ask Gemini for the best frame timestamp
                        try:
                            best_frame_time_ms = google_ai.fetch_info_from_google_ai(
                                prompt="", video_file=google_file
                            )
                            if best_frame_time_ms:
                                self.set_extract_result(
                                    task_id=task["id"],
                                    file_state="ACTIVE",
                                    best_frame_time_ms=best_frame_time_ms,
                                )
                            else:
                                update_task_queue_status(
                                    db_client=self.db_client,
                                    task_id=task["id"],
                                    task="extract",
                                    ori_status=1,
                                    new_status=99,
                                )
                            # delete local file and google file
                            if os.path.exists(video_local_path):
                                os.remove(video_local_path)
                            google_ai.delete_video(file_name)
                        except Exception as e:
                            print(e)
                            update_task_queue_status(
                                db_client=self.db_client,
                                task_id=task["id"],
                                task="extract",
                                ori_status=1,
                                new_status=99,
                            )
            except Exception as e:
                print(f"extract task error: {e}")
                update_task_queue_status(
                    db_client=self.db_client,
                    task_id=task["id"],
                    task="extract",
                    ori_status=1,
                    new_status=99,
                )