extract_video_best_frame.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """
  2. @author luojunhui
  3. @desc find best frame from each video
  4. """
  5. import os
  6. import datetime
  7. from tqdm import tqdm
  8. from pymysql.cursors import DictCursor
  9. from applications.api import GoogleAIAPI
  10. from applications.const import GoogleVideoUnderstandTaskConst
  11. from applications.db import DatabaseConnector
  12. from config import long_articles_config
  13. from coldStartTasks.ai_pipeline.basic import download_file
  14. from coldStartTasks.ai_pipeline.basic import update_task_queue_status
  15. from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
  16. from coldStartTasks.ai_pipeline.basic import extract_best_frame_prompt
  17. from coldStartTasks.ai_pipeline.basic import get_video_cover
  18. from coldStartTasks.ai_pipeline.basic import normalize_time_str
  19. const = GoogleVideoUnderstandTaskConst()
  20. table_name = "long_articles_new_video_cover"
  21. dir_name = "static"
  22. POOL_SIZE = 10
  23. google_ai = GoogleAIAPI()
  24. class ExtractVideoBestFrame:
  25. """
  26. extract video best frame from each video by GeminiAI
  27. """
  28. def __init__(self):
  29. self.db_client = DatabaseConnector(db_config=long_articles_config)
  30. self.db_client.connect()
  31. def get_upload_task_list(self, task_num: int = POOL_SIZE) -> list[dict]:
  32. """
  33. get upload task list
  34. """
  35. fetch_query = f"""
  36. select id, video_oss_path from {table_name}
  37. where upload_status = {const.INIT_STATUS} and priority = 1
  38. limit {task_num};
  39. """
  40. upload_task_list = self.db_client.fetch(
  41. query=fetch_query, cursor_type=DictCursor
  42. )
  43. return upload_task_list
  44. def get_extract_task_list(self, task_num: int = POOL_SIZE) -> list[dict]:
  45. """
  46. get extract task list
  47. """
  48. fetch_query = f"""
  49. select id, file_name from {table_name}
  50. where upload_status = {const.SUCCESS_STATUS} and extract_status = {const.INIT_STATUS}
  51. order by file_expire_time
  52. limit {task_num};
  53. """
  54. extract_task_list = self.db_client.fetch(
  55. query=fetch_query, cursor_type=DictCursor
  56. )
  57. return extract_task_list
  58. def get_cover_task_list(self) -> list[dict]:
  59. """
  60. get cover task list
  61. """
  62. fetch_query = f"""
  63. select id, video_oss_path, best_frame_time_ms from {table_name}
  64. where extract_status = {const.SUCCESS_STATUS} and get_cover_status = {const.INIT_STATUS};
  65. """
  66. extract_task_list = self.db_client.fetch(
  67. query=fetch_query, cursor_type=DictCursor
  68. )
  69. return extract_task_list
  70. def get_processing_task_pool_size(self) -> int:
  71. """
  72. get processing task pool size
  73. """
  74. fetch_query = f"""
  75. select count(1) as pool_size from {table_name}
  76. where upload_status = {const.SUCCESS_STATUS} and file_state = 'PROCESSING' and extract_status = {const.INIT_STATUS};
  77. """
  78. fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
  79. processing_task_pool_size = (
  80. fetch_response[0]["pool_size"] if fetch_response else 0
  81. )
  82. return processing_task_pool_size
  83. def set_upload_result(
  84. self, task_id: int, file_name: str, file_state: str, file_expire_time: str
  85. ) -> int:
  86. update_query = f"""
  87. update {table_name}
  88. set upload_status = %s, upload_status_ts = %s,
  89. file_name = %s, file_state = %s, file_expire_time = %s
  90. where id = %s and upload_status = %s;
  91. """
  92. update_rows = self.db_client.save(
  93. query=update_query,
  94. params=(
  95. const.SUCCESS_STATUS,
  96. datetime.datetime.now(),
  97. file_name,
  98. file_state,
  99. file_expire_time,
  100. task_id,
  101. const.PROCESSING_STATUS,
  102. ),
  103. )
  104. return update_rows
  105. def set_extract_result(
  106. self, task_id: int, file_state: str, best_frame_tims_ms: str
  107. ) -> int:
  108. update_query = f"""
  109. update {table_name}
  110. set extract_status = %s, extract_status_ts = %s,
  111. file_state = %s, best_frame_time_ms = %s
  112. where id = %s and extract_status = %s;
  113. """
  114. update_rows = self.db_client.save(
  115. query=update_query,
  116. params=(
  117. const.SUCCESS_STATUS,
  118. datetime.datetime.now(),
  119. file_state,
  120. best_frame_tims_ms,
  121. task_id,
  122. const.PROCESSING_STATUS,
  123. ),
  124. )
  125. return update_rows
  126. def upload_video_to_gemini_ai(
  127. self, max_processing_pool_size: int = POOL_SIZE
  128. ) -> None:
  129. # upload video to gemini ai
  130. roll_back_lock_tasks_count = roll_back_lock_tasks(
  131. db_client=self.db_client,
  132. task="upload",
  133. init_status=const.INIT_STATUS,
  134. processing_status=const.PROCESSING_STATUS,
  135. max_process_time=const.MAX_PROCESSING_TIME,
  136. )
  137. print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
  138. processing_task_num = self.get_processing_task_pool_size()
  139. res_task_num = max_processing_pool_size - processing_task_num
  140. if res_task_num:
  141. upload_task_list = self.get_upload_task_list(task_num=res_task_num)
  142. for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
  143. lock_status = update_task_queue_status(
  144. db_client=self.db_client,
  145. task_id=task["id"],
  146. task="upload",
  147. ori_status=const.INIT_STATUS,
  148. new_status=const.PROCESSING_STATUS,
  149. )
  150. if not lock_status:
  151. continue
  152. try:
  153. file_path = download_file(task["id"], task["video_oss_path"])
  154. upload_response = google_ai.upload_file(file_path)
  155. if upload_response:
  156. file_name, file_state, expire_time = upload_response
  157. self.set_upload_result(
  158. task_id=task["id"],
  159. file_name=file_name,
  160. file_state=file_state,
  161. file_expire_time=expire_time,
  162. )
  163. else:
  164. # set status as fail
  165. update_task_queue_status(
  166. db_client=self.db_client,
  167. task_id=task["id"],
  168. task="upload",
  169. ori_status=const.PROCESSING_STATUS,
  170. new_status=const.FAIL_STATUS,
  171. )
  172. except Exception as e:
  173. print(f"download_file error: {e}")
  174. update_task_queue_status(
  175. db_client=self.db_client,
  176. task_id=task["id"],
  177. task="upload",
  178. ori_status=const.PROCESSING_STATUS,
  179. new_status=const.FAIL_STATUS,
  180. )
  181. continue
  182. else:
  183. print("Processing task pool is full")
  184. def extract_best_frame_with_gemini_ai(self):
  185. # roll back lock tasks
  186. roll_back_lock_tasks_count = roll_back_lock_tasks(
  187. db_client=self.db_client,
  188. task="extract",
  189. init_status=const.INIT_STATUS,
  190. processing_status=const.PROCESSING_STATUS,
  191. max_process_time=const.MAX_PROCESSING_TIME,
  192. )
  193. print("roll_back_lock_tasks_count", roll_back_lock_tasks_count)
  194. # do extract frame task
  195. task_list = self.get_extract_task_list()
  196. for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
  197. # lock task
  198. lock_status = update_task_queue_status(
  199. db_client=self.db_client,
  200. task_id=task["id"],
  201. task="extract",
  202. ori_status=const.INIT_STATUS,
  203. new_status=const.PROCESSING_STATUS,
  204. )
  205. if not lock_status:
  206. continue
  207. file_name = task["file_name"]
  208. video_local_path = os.path.join(dir_name, "{}.mp4".format(task["id"]))
  209. try:
  210. google_file = google_ai.get_google_file(file_name)
  211. state = google_file.state.name
  212. match state:
  213. case "PROCESSING":
  214. # google is still processing this video
  215. update_task_queue_status(
  216. db_client=self.db_client,
  217. task_id=task["id"],
  218. task="extract",
  219. ori_status=const.PROCESSING_STATUS,
  220. new_status=const.INIT_STATUS,
  221. )
  222. print("this video is still processing")
  223. case "FAILED":
  224. # google process this video failed
  225. update_query = f"""
  226. update {table_name}
  227. set file_state = %s, extract_status = %s, extract_status_ts = %s
  228. where id = %s and extract_status = %s;
  229. """
  230. update_rows = self.db_client.save(
  231. query=update_query,
  232. params=(
  233. "FAILED",
  234. const.FAIL_STATUS,
  235. datetime.datetime.now(),
  236. task["id"],
  237. const.PROCESSING_STATUS,
  238. ),
  239. )
  240. case "ACTIVE":
  241. # video process successfully
  242. try:
  243. best_frame_tims_ms = google_ai.fetch_info_from_google_ai(
  244. prompt=extract_best_frame_prompt(),
  245. video_file=google_file,
  246. )
  247. if best_frame_tims_ms:
  248. self.set_extract_result(
  249. task_id=task["id"],
  250. file_state="ACTIVE",
  251. best_frame_tims_ms=best_frame_tims_ms.strip(),
  252. )
  253. else:
  254. update_task_queue_status(
  255. db_client=self.db_client,
  256. task_id=task["id"],
  257. task="extract",
  258. ori_status=const.PROCESSING_STATUS,
  259. new_status=const.FAIL_STATUS,
  260. )
  261. # delete local file and google file
  262. if os.path.exists(video_local_path):
  263. os.remove(video_local_path)
  264. google_ai.delete_video(file_name)
  265. except Exception as e:
  266. print(e)
  267. update_task_queue_status(
  268. db_client=self.db_client,
  269. task_id=task["id"],
  270. task="extract",
  271. ori_status=const.PROCESSING_STATUS,
  272. new_status=const.FAIL_STATUS,
  273. )
  274. except Exception as e:
  275. print(f"update_task_queue_status error: {e}")
  276. update_task_queue_status(
  277. db_client=self.db_client,
  278. task_id=task["id"],
  279. task="extract",
  280. ori_status=const.PROCESSING_STATUS,
  281. new_status=const.FAIL_STATUS,
  282. )
  283. def get_cover_with_best_frame(self):
  284. """
  285. get cover with best frame
  286. """
  287. # get task list
  288. task_list = self.get_cover_task_list()
  289. for task in tqdm(task_list, desc="extract_cover_with_ffmpeg"):
  290. # lock task
  291. lock_status = update_task_queue_status(
  292. db_client=self.db_client,
  293. task_id=task["id"],
  294. task="get_cover",
  295. ori_status=const.INIT_STATUS,
  296. new_status=const.PROCESSING_STATUS,
  297. )
  298. if not lock_status:
  299. continue
  300. time_str = normalize_time_str(task["best_frame_time_ms"])
  301. if time_str:
  302. response = get_video_cover(
  303. video_oss_path=task["video_oss_path"], time_millisecond_str=time_str
  304. )
  305. print(response)
  306. if response["success"] and response["data"]:
  307. cover_oss_path = response["data"]
  308. update_query = f"""
  309. update {table_name}
  310. set cover_oss_path = %s, get_cover_status = %s, get_cover_status_ts = %s
  311. where id = %s and get_cover_status = %s;
  312. """
  313. update_rows = self.db_client.save(
  314. query=update_query,
  315. params=(
  316. cover_oss_path,
  317. const.SUCCESS_STATUS,
  318. datetime.datetime.now(),
  319. task["id"],
  320. const.PROCESSING_STATUS,
  321. ),
  322. )
  323. else:
  324. update_task_queue_status(
  325. db_client=self.db_client,
  326. task_id=task["id"],
  327. task="get_cover",
  328. ori_status=const.PROCESSING_STATUS,
  329. new_status=const.FAIL_STATUS,
  330. )
  331. else:
  332. update_task_queue_status(
  333. db_client=self.db_client,
  334. task_id=task["id"],
  335. task="get_cover",
  336. ori_status=const.PROCESSING_STATUS,
  337. new_status=const.FAIL_STATUS,
  338. )