extract_video_best_frame.py 17 KB


  1. """
  2. @author luojunhui
  3. @desc find best frame from each video
  4. """
  5. import os
  6. import datetime
  7. import traceback
  8. from tqdm import tqdm
  9. from pymysql.cursors import DictCursor
  10. from applications import log
  11. from applications.api import GoogleAIAPI
  12. from applications.const import GoogleVideoUnderstandTaskConst
  13. from applications.db import DatabaseConnector
  14. from config import long_articles_config
  15. from coldStartTasks.ai_pipeline.basic import download_file
  16. from coldStartTasks.ai_pipeline.basic import update_task_queue_status
  17. from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
  18. from coldStartTasks.ai_pipeline.basic import extract_best_frame_prompt
  19. from coldStartTasks.ai_pipeline.basic import get_video_cover
  20. from coldStartTasks.ai_pipeline.basic import normalize_time_str
# Shared task constants for this module (table name, status codes, pool size,
# task name, local download directory, ...), read by every method below.
const: GoogleVideoUnderstandTaskConst = GoogleVideoUnderstandTaskConst()
# Single Gemini client used by all ExtractVideoBestFrame instances; note that it
# is constructed at import time.
google_ai: GoogleAIAPI = GoogleAIAPI()
  23. class ExtractVideoBestFrame:
  24. """
  25. extract video best frame from each video by GeminiAI
  26. """
  27. def __init__(self):
  28. self.db_client = DatabaseConnector(db_config=long_articles_config)
  29. self.db_client.connect()
  30. def _roll_back_lock_tasks(self, task: str) -> int:
  31. return roll_back_lock_tasks(
  32. db_client=self.db_client,
  33. task=task,
  34. init_status=const.INIT_STATUS,
  35. processing_status=const.PROCESSING_STATUS,
  36. max_process_time=const.MAX_PROCESSING_TIME,
  37. )
  38. def _lock_task(self, task_id: int, task_name) -> int:
  39. return update_task_queue_status(
  40. db_client=self.db_client,
  41. task_id=task_id,
  42. task=task_name,
  43. ori_status=const.INIT_STATUS,
  44. new_status=const.PROCESSING_STATUS,
  45. )
  46. def get_upload_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  47. """
  48. get upload task list
  49. """
  50. fetch_query = f"""
  51. select id, video_oss_path from {const.TABLE_NAME}
  52. where upload_status = {const.INIT_STATUS}
  53. order by priority desc
  54. limit {task_num};
  55. """
  56. upload_task_list = self.db_client.fetch(
  57. query=fetch_query, cursor_type=DictCursor
  58. )
  59. return upload_task_list
  60. def get_extract_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  61. """
  62. get extract task list
  63. """
  64. fetch_query = f"""
  65. select id, file_name from {const.TABLE_NAME}
  66. where upload_status = {const.SUCCESS_STATUS} and extract_status = {const.INIT_STATUS}
  67. order by file_expire_time
  68. limit {task_num};
  69. """
  70. extract_task_list = self.db_client.fetch(
  71. query=fetch_query, cursor_type=DictCursor
  72. )
  73. return extract_task_list
  74. def get_cover_task_list(self) -> list[dict]:
  75. """
  76. get cover task list
  77. """
  78. fetch_query = f"""
  79. select id, video_oss_path, best_frame_time_ms from {const.TABLE_NAME}
  80. where extract_status = {const.SUCCESS_STATUS} and get_cover_status = {const.INIT_STATUS};
  81. """
  82. extract_task_list = self.db_client.fetch(
  83. query=fetch_query, cursor_type=DictCursor
  84. )
  85. return extract_task_list
  86. def get_processing_task_pool_size(self) -> int:
  87. """
  88. get processing task pool size
  89. """
  90. fetch_query = f"""
  91. select count(1) as pool_size from {const.TABLE_NAME}
  92. where upload_status = {const.SUCCESS_STATUS} and file_state = 'PROCESSING' and extract_status = {const.INIT_STATUS};
  93. """
  94. fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
  95. processing_task_pool_size = (
  96. fetch_response[0]["pool_size"] if fetch_response else 0
  97. )
  98. return processing_task_pool_size
  99. def set_upload_result(
  100. self, task_id: int, file_name: str, file_state: str, file_expire_time: str
  101. ) -> int:
  102. update_query = f"""
  103. update {const.TABLE_NAME}
  104. set upload_status = %s, upload_status_ts = %s,
  105. file_name = %s, file_state = %s, file_expire_time = %s
  106. where id = %s and upload_status = %s;
  107. """
  108. update_rows = self.db_client.save(
  109. query=update_query,
  110. params=(
  111. const.SUCCESS_STATUS,
  112. datetime.datetime.now(),
  113. file_name,
  114. file_state,
  115. file_expire_time,
  116. task_id,
  117. const.PROCESSING_STATUS,
  118. ),
  119. )
  120. return update_rows
  121. def set_extract_result(
  122. self, task_id: int, file_state: str, best_frame_time_ms: str
  123. ) -> int:
  124. update_query = f"""
  125. update {const.TABLE_NAME}
  126. set extract_status = %s, extract_status_ts = %s,
  127. file_state = %s, best_frame_time_ms = %s
  128. where id = %s and extract_status = %s;
  129. """
  130. update_rows = self.db_client.save(
  131. query=update_query,
  132. params=(
  133. const.SUCCESS_STATUS,
  134. datetime.datetime.now(),
  135. file_state,
  136. best_frame_time_ms,
  137. task_id,
  138. const.PROCESSING_STATUS,
  139. ),
  140. )
  141. return update_rows
  142. def set_cover_result(self, task_id: int, cover_oss_path: str) -> int:
  143. update_query = f"""
  144. update {const.TABLE_NAME}
  145. set cover_oss_path = %s, get_cover_status = %s, get_cover_status_ts = %s
  146. where id = %s and get_cover_status = %s;
  147. """
  148. update_rows = self.db_client.save(
  149. query=update_query,
  150. params=(
  151. cover_oss_path,
  152. const.SUCCESS_STATUS,
  153. datetime.datetime.now(),
  154. task_id,
  155. const.PROCESSING_STATUS,
  156. ),
  157. )
  158. return update_rows
  159. def upload_each_video(self, task: dict) -> None:
  160. lock_status = self._lock_task(task_id=task["id"], task_name="upload")
  161. if not lock_status:
  162. return None
  163. try:
  164. file_path = download_file(task["id"], task["video_oss_path"])
  165. upload_response = google_ai.upload_file(file_path)
  166. if upload_response:
  167. file_name, file_state, expire_time = upload_response
  168. self.set_upload_result(
  169. task_id=task["id"],
  170. file_name=file_name,
  171. file_state=file_state,
  172. file_expire_time=expire_time,
  173. )
  174. return None
  175. else:
  176. # set status as fail
  177. update_task_queue_status(
  178. db_client=self.db_client,
  179. task_id=task["id"],
  180. task="upload",
  181. ori_status=const.PROCESSING_STATUS,
  182. new_status=const.FAIL_STATUS,
  183. )
  184. return None
  185. except Exception as e:
  186. log(
  187. task=const.TASK_NAME,
  188. function="upload_video_to_gemini_ai",
  189. message="task_failed",
  190. data={
  191. "task_id": task["id"],
  192. "track_back": traceback.format_exc(),
  193. "error": str(e),
  194. },
  195. )
  196. update_task_queue_status(
  197. db_client=self.db_client,
  198. task_id=task["id"],
  199. task="upload",
  200. ori_status=const.PROCESSING_STATUS,
  201. new_status=const.FAIL_STATUS,
  202. )
  203. return None
  204. def upload_video_to_gemini_ai(
  205. self, max_processing_pool_size: int = const.POOL_SIZE
  206. ) -> None:
  207. # upload video to gemini ai
  208. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="upload")
  209. log(
  210. task=const.TASK_NAME,
  211. function="upload_video_to_gemini_ai",
  212. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  213. )
  214. processing_task_num = self.get_processing_task_pool_size()
  215. res_task_num = max_processing_pool_size - processing_task_num
  216. if res_task_num:
  217. upload_task_list = self.get_upload_task_list(task_num=res_task_num)
  218. for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
  219. self.upload_each_video(task=task)
  220. else:
  221. log(
  222. task=const.TASK_NAME,
  223. function="upload_video_to_gemini_ai",
  224. message="reach pool size, no more space for task to upload",
  225. )
  226. def extract_each_video(self, task: dict) -> None:
  227. # lock task
  228. lock_status = self._lock_task(task_id=task["id"], task_name="extract")
  229. if not lock_status:
  230. return None
  231. file_name = task["file_name"]
  232. video_local_path = os.path.join(const.DIR_NAME, "{}.mp4".format(task["id"]))
  233. try:
  234. google_file = google_ai.get_google_file(file_name)
  235. state = google_file.state.name
  236. match state:
  237. case "PROCESSING":
  238. # google is still processing this video
  239. update_task_queue_status(
  240. db_client=self.db_client,
  241. task_id=task["id"],
  242. task="extract",
  243. ori_status=const.PROCESSING_STATUS,
  244. new_status=const.INIT_STATUS,
  245. )
  246. log(
  247. task=const.TASK_NAME,
  248. function="extract_best_frame_with_gemini_ai",
  249. message="google is still processing this video",
  250. data={
  251. "task_id": task["id"],
  252. "file_name": file_name,
  253. "state": state,
  254. },
  255. )
  256. case "FAILED":
  257. # google process this video failed
  258. update_query = f"""
  259. update {const.TABLE_NAME}
  260. set file_state = %s, extract_status = %s, extract_status_ts = %s
  261. where id = %s and extract_status = %s;
  262. """
  263. self.db_client.save(
  264. query=update_query,
  265. params=(
  266. "FAILED",
  267. const.FAIL_STATUS,
  268. datetime.datetime.now(),
  269. task["id"],
  270. const.PROCESSING_STATUS,
  271. ),
  272. )
  273. log(
  274. task=const.TASK_NAME,
  275. function="extract_best_frame_with_gemini_ai",
  276. message="google process this video failed",
  277. data={
  278. "task_id": task["id"],
  279. "file_name": file_name,
  280. "state": state,
  281. },
  282. )
  283. case "ACTIVE":
  284. # video process successfully
  285. try:
  286. best_frame_time_ms = google_ai.fetch_info_from_google_ai(
  287. prompt=extract_best_frame_prompt(),
  288. video_file=google_file,
  289. )
  290. if best_frame_time_ms:
  291. self.set_extract_result(
  292. task_id=task["id"],
  293. file_state="ACTIVE",
  294. best_frame_time_ms=best_frame_time_ms.strip(),
  295. )
  296. else:
  297. update_task_queue_status(
  298. db_client=self.db_client,
  299. task_id=task["id"],
  300. task="extract",
  301. ori_status=const.PROCESSING_STATUS,
  302. new_status=const.FAIL_STATUS,
  303. )
  304. # delete local file and google file
  305. if os.path.exists(video_local_path):
  306. os.remove(video_local_path)
  307. google_ai.delete_video(file_name)
  308. log(
  309. task=const.TASK_NAME,
  310. function="extract_best_frame_with_gemini_ai",
  311. message="video process successfully",
  312. data={
  313. "task_id": task["id"],
  314. "file_name": file_name,
  315. "state": state,
  316. "best_frame_time_ms": best_frame_time_ms,
  317. },
  318. )
  319. except Exception as e:
  320. log(
  321. task=const.TASK_NAME,
  322. function="extract_best_frame_with_gemini_ai",
  323. message="task_failed_inside_cycle",
  324. data={
  325. "task_id": task["id"],
  326. "track_back": traceback.format_exc(),
  327. "error": str(e),
  328. },
  329. )
  330. update_task_queue_status(
  331. db_client=self.db_client,
  332. task_id=task["id"],
  333. task="extract",
  334. ori_status=const.PROCESSING_STATUS,
  335. new_status=const.FAIL_STATUS,
  336. )
  337. except Exception as e:
  338. log(
  339. task=const.TASK_NAME,
  340. function="extract_best_frame_with_gemini_ai",
  341. message="task_failed_outside_cycle",
  342. data={
  343. "task_id": task["id"],
  344. "track_back": traceback.format_exc(),
  345. "error": str(e),
  346. },
  347. )
  348. update_task_queue_status(
  349. db_client=self.db_client,
  350. task_id=task["id"],
  351. task="extract",
  352. ori_status=const.PROCESSING_STATUS,
  353. new_status=const.FAIL_STATUS,
  354. )
  355. def extract_best_frame_with_gemini_ai(self):
  356. # roll back lock tasks
  357. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="extract")
  358. log(
  359. task=const.TASK_NAME,
  360. function="extract_best_frame_with_gemini_ai",
  361. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  362. )
  363. # do extract frame task
  364. task_list = self.get_extract_task_list()
  365. for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
  366. self.extract_each_video(task=task)
  367. def get_each_cover(self, task: dict) -> None:
  368. lock_status = self._lock_task(task_id=task["id"], task_name="get_cover")
  369. if not lock_status:
  370. return None
  371. time_str = normalize_time_str(task["best_frame_time_ms"])
  372. if time_str:
  373. response = get_video_cover(
  374. video_oss_path=task["video_oss_path"], time_millisecond_str=time_str
  375. )
  376. log(
  377. task=const.TASK_NAME,
  378. function="extract_cover_with_ffmpeg",
  379. message="get_video_cover_with_ffmpeg",
  380. data={
  381. "task_id": task["id"],
  382. "video_oss_path": task["video_oss_path"],
  383. "time_millisecond_str": time_str,
  384. "response": response,
  385. },
  386. )
  387. if response["success"] and response["data"]:
  388. cover_oss_path = response["data"]
  389. self.set_cover_result(task_id=task["id"], cover_oss_path=cover_oss_path)
  390. else:
  391. update_task_queue_status(
  392. db_client=self.db_client,
  393. task_id=task["id"],
  394. task="get_cover",
  395. ori_status=const.PROCESSING_STATUS,
  396. new_status=const.FAIL_STATUS,
  397. )
  398. else:
  399. log(
  400. task=const.TASK_NAME,
  401. function="extract_cover_with_ffmpeg",
  402. message="time_str format is not correct",
  403. data={
  404. "task_id": task["id"],
  405. "video_oss_path": task["video_oss_path"],
  406. "time_millisecond_str": time_str,
  407. },
  408. )
  409. update_task_queue_status(
  410. db_client=self.db_client,
  411. task_id=task["id"],
  412. task="get_cover",
  413. ori_status=const.PROCESSING_STATUS,
  414. new_status=const.FAIL_STATUS,
  415. )
  416. def get_cover_with_best_frame(self):
  417. """
  418. get cover with best frame
  419. """
  420. # roll back lock tasks
  421. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="get_cover")
  422. log(
  423. task=const.TASK_NAME,
  424. function="extract_cover_with_ffmpeg",
  425. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  426. )
  427. # get task list
  428. task_list = self.get_cover_task_list()
  429. for task in tqdm(task_list, desc="extract_cover_with_ffmpeg"):
  430. self.get_each_cover(task=task)