# extract_video_best_frame.py (21 KB)


  1. """
  2. @author luojunhui
  3. @desc find best frame from each video
  4. """
  5. import os
  6. import datetime
  7. import traceback
  8. from tqdm import tqdm
  9. from pymysql.cursors import DictCursor
  10. from applications import log
  11. from applications.api import GoogleAIAPI
  12. from applications.const import GoogleVideoUnderstandTaskConst
  13. from applications.db import DatabaseConnector
  14. from config import long_articles_config
  15. from coldStartTasks.ai_pipeline.basic import download_file
  16. from coldStartTasks.ai_pipeline.basic import update_task_queue_status
  17. from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
  18. from coldStartTasks.ai_pipeline.basic import extract_best_frame_prompt
  19. from coldStartTasks.ai_pipeline.basic import get_video_cover
  20. from coldStartTasks.ai_pipeline.basic import normalize_time_str
  21. const = GoogleVideoUnderstandTaskConst()
  22. google_ai = GoogleAIAPI(api_key='AIzaSyBAFFL6yHa4kUK-YcuAw8tbiSHQB6oJG34')
  23. class ExtractVideoBestFrame:
  24. """
  25. extract video best frame from each video by GeminiAI
  26. """
  27. def __init__(self):
  28. self.db_client = DatabaseConnector(db_config=long_articles_config)
  29. self.db_client.connect()
  30. self.api_status = True
  31. def _roll_back_lock_tasks(self, task: str) -> int:
  32. return roll_back_lock_tasks(
  33. db_client=self.db_client,
  34. task=task,
  35. init_status=const.INIT_STATUS,
  36. processing_status=const.PROCESSING_STATUS,
  37. max_process_time=const.MAX_PROCESSING_TIME,
  38. )
  39. def _lock_task(self, task_id: int, task_name) -> int:
  40. return update_task_queue_status(
  41. db_client=self.db_client,
  42. task_id=task_id,
  43. task=task_name,
  44. ori_status=const.INIT_STATUS,
  45. new_status=const.PROCESSING_STATUS,
  46. )
  47. def get_upload_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  48. """
  49. get upload task list
  50. """
  51. fetch_query = f"""
  52. select id, video_oss_path from {const.TABLE_NAME}
  53. where upload_status = {const.INIT_STATUS}
  54. order by priority desc
  55. limit {task_num};
  56. """
  57. upload_task_list = self.db_client.fetch(
  58. query=fetch_query, cursor_type=DictCursor
  59. )
  60. return upload_task_list
  61. def get_extract_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  62. """
  63. get extract task list
  64. """
  65. fetch_query = f"""
  66. select id, file_name from {const.TABLE_NAME}
  67. where upload_status = {const.SUCCESS_STATUS} and extract_status = {const.INIT_STATUS}
  68. order by file_expire_time
  69. limit {task_num};
  70. """
  71. extract_task_list = self.db_client.fetch(
  72. query=fetch_query, cursor_type=DictCursor
  73. )
  74. return extract_task_list
  75. def get_cover_task_list(self) -> list[dict]:
  76. """
  77. get cover task list
  78. """
  79. fetch_query = f"""
  80. select id, video_oss_path, best_frame_time_ms from {const.TABLE_NAME}
  81. where extract_status = {const.SUCCESS_STATUS} and get_cover_status = {const.INIT_STATUS};
  82. """
  83. extract_task_list = self.db_client.fetch(
  84. query=fetch_query, cursor_type=DictCursor
  85. )
  86. return extract_task_list
  87. def get_processing_task_pool_size(self) -> int:
  88. """
  89. get processing task pool size
  90. """
  91. fetch_query = f"""
  92. select count(1) as pool_size from {const.TABLE_NAME}
  93. where upload_status = {const.SUCCESS_STATUS} and file_state = 'PROCESSING' and extract_status = {const.INIT_STATUS};
  94. """
  95. fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
  96. processing_task_pool_size = (
  97. fetch_response[0]["pool_size"] if fetch_response else 0
  98. )
  99. return processing_task_pool_size
  100. def set_upload_result(
  101. self, task_id: int, file_name: str, file_state: str, file_expire_time: str
  102. ) -> int:
  103. update_query = f"""
  104. update {const.TABLE_NAME}
  105. set upload_status = %s, upload_status_ts = %s,
  106. file_name = %s, file_state = %s, file_expire_time = %s
  107. where id = %s and upload_status = %s;
  108. """
  109. update_rows = self.db_client.save(
  110. query=update_query,
  111. params=(
  112. const.SUCCESS_STATUS,
  113. datetime.datetime.now(),
  114. file_name,
  115. file_state,
  116. file_expire_time,
  117. task_id,
  118. const.PROCESSING_STATUS,
  119. ),
  120. )
  121. return update_rows
  122. def set_extract_result(
  123. self, task_id: int, file_state: str, best_frame_time_ms: str
  124. ) -> int:
  125. update_query = f"""
  126. update {const.TABLE_NAME}
  127. set extract_status = %s, extract_status_ts = %s,
  128. file_state = %s, best_frame_time_ms = %s
  129. where id = %s and extract_status = %s;
  130. """
  131. update_rows = self.db_client.save(
  132. query=update_query,
  133. params=(
  134. const.SUCCESS_STATUS,
  135. datetime.datetime.now(),
  136. file_state,
  137. best_frame_time_ms,
  138. task_id,
  139. const.PROCESSING_STATUS,
  140. ),
  141. )
  142. return update_rows
  143. def set_cover_result(self, task_id: int, cover_oss_path: str) -> int:
  144. update_query = f"""
  145. update {const.TABLE_NAME}
  146. set cover_oss_path = %s, get_cover_status = %s, get_cover_status_ts = %s
  147. where id = %s and get_cover_status = %s;
  148. """
  149. update_rows = self.db_client.save(
  150. query=update_query,
  151. params=(
  152. cover_oss_path,
  153. const.SUCCESS_STATUS,
  154. datetime.datetime.now(),
  155. task_id,
  156. const.PROCESSING_STATUS,
  157. ),
  158. )
  159. return update_rows
  160. def upload_each_video(self, task: dict) -> None:
  161. lock_status = self._lock_task(task_id=task["id"], task_name="upload")
  162. if not lock_status:
  163. return None
  164. try:
  165. file_path = download_file(task["id"], task["video_oss_path"])
  166. upload_response = google_ai.upload_file(file_path)
  167. if upload_response:
  168. file_name, file_state, expire_time = upload_response
  169. self.set_upload_result(
  170. task_id=task["id"],
  171. file_name=file_name,
  172. file_state=file_state,
  173. file_expire_time=expire_time,
  174. )
  175. return None
  176. else:
  177. # set status as fail
  178. update_task_queue_status(
  179. db_client=self.db_client,
  180. task_id=task["id"],
  181. task="upload",
  182. ori_status=const.PROCESSING_STATUS,
  183. new_status=const.FAIL_STATUS,
  184. )
  185. return None
  186. except Exception as e:
  187. log(
  188. task=const.TASK_NAME,
  189. function="upload_video_to_gemini_ai",
  190. message="task_failed",
  191. data={
  192. "task_id": task["id"],
  193. "track_back": traceback.format_exc(),
  194. "error": str(e),
  195. },
  196. )
  197. update_task_queue_status(
  198. db_client=self.db_client,
  199. task_id=task["id"],
  200. task="upload",
  201. ori_status=const.PROCESSING_STATUS,
  202. new_status=const.FAIL_STATUS,
  203. )
  204. return None
  205. def upload_video_to_gemini_ai(
  206. self, max_processing_pool_size: int = const.POOL_SIZE
  207. ) -> None:
  208. if not self.api_status:
  209. log(
  210. task=const.TASK_NAME,
  211. function="upload_video_to_gemini_ai",
  212. message="api_status is False, skip this task",
  213. )
  214. return
  215. # upload video to gemini ai
  216. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="upload")
  217. log(
  218. task=const.TASK_NAME,
  219. function="upload_video_to_gemini_ai",
  220. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  221. )
  222. processing_task_num = self.get_processing_task_pool_size()
  223. res_task_num = max_processing_pool_size - processing_task_num
  224. if res_task_num:
  225. upload_task_list = self.get_upload_task_list(task_num=res_task_num)
  226. for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
  227. self.upload_each_video(task=task)
  228. else:
  229. log(
  230. task=const.TASK_NAME,
  231. function="upload_video_to_gemini_ai",
  232. message="reach pool size, no more space for task to upload",
  233. )
  234. def extract_each_video(self, task: dict) -> None:
  235. # lock task
  236. lock_status = self._lock_task(task_id=task["id"], task_name="extract")
  237. if not lock_status:
  238. return None
  239. file_name = task["file_name"]
  240. video_local_path = os.path.join(const.DIR_NAME, "{}.mp4".format(task["id"]))
  241. try:
  242. google_file = google_ai.get_google_file(file_name)
  243. state = google_file.state.name
  244. match state:
  245. case "PROCESSING":
  246. # google is still processing this video
  247. update_task_queue_status(
  248. db_client=self.db_client,
  249. task_id=task["id"],
  250. task="extract",
  251. ori_status=const.PROCESSING_STATUS,
  252. new_status=const.INIT_STATUS,
  253. )
  254. log(
  255. task=const.TASK_NAME,
  256. function="extract_best_frame_with_gemini_ai",
  257. message="google is still processing this video",
  258. data={
  259. "task_id": task["id"],
  260. "file_name": file_name,
  261. "state": state,
  262. },
  263. )
  264. case "FAILED":
  265. # google process this video failed
  266. update_query = f"""
  267. update {const.TABLE_NAME}
  268. set file_state = %s, extract_status = %s, extract_status_ts = %s
  269. where id = %s and extract_status = %s;
  270. """
  271. self.db_client.save(
  272. query=update_query,
  273. params=(
  274. "FAILED",
  275. const.FAIL_STATUS,
  276. datetime.datetime.now(),
  277. task["id"],
  278. const.PROCESSING_STATUS,
  279. ),
  280. )
  281. log(
  282. task=const.TASK_NAME,
  283. function="extract_best_frame_with_gemini_ai",
  284. message="google process this video failed",
  285. data={
  286. "task_id": task["id"],
  287. "file_name": file_name,
  288. "state": state,
  289. },
  290. )
  291. case "ACTIVE":
  292. # video process successfully
  293. try:
  294. response = google_ai.fetch_info_from_google_ai(
  295. prompt=extract_best_frame_prompt(),
  296. video_file=google_file,
  297. )
  298. code = response['code']
  299. match code:
  300. case 200:
  301. best_frame_time_ms = response['text']
  302. if best_frame_time_ms:
  303. self.set_extract_result(
  304. task_id=task["id"],
  305. file_state="ACTIVE",
  306. best_frame_time_ms=best_frame_time_ms.strip(),
  307. )
  308. else:
  309. update_task_queue_status(
  310. db_client=self.db_client,
  311. task_id=task["id"],
  312. task="extract",
  313. ori_status=const.PROCESSING_STATUS,
  314. new_status=const.FAIL_STATUS,
  315. )
  316. # delete local file and google file
  317. if os.path.exists(video_local_path):
  318. os.remove(video_local_path)
  319. google_ai.delete_video(file_name)
  320. log(
  321. task=const.TASK_NAME,
  322. function="extract_best_frame_with_gemini_ai",
  323. message="video process successfully",
  324. data={
  325. "task_id": task["id"],
  326. "file_name": file_name,
  327. "state": state,
  328. "best_frame_time_ms": best_frame_time_ms,
  329. },
  330. )
  331. case 429:
  332. # resource exhausted
  333. log(
  334. task=const.TASK_NAME,
  335. function="extract_best_frame_with_gemini_ai",
  336. message="resource exhausted",
  337. data={
  338. "task_id": task["id"],
  339. "file_name": file_name,
  340. "state": state,
  341. "response": response
  342. },
  343. )
  344. update_task_queue_status(
  345. db_client=self.db_client,
  346. task_id=task["id"],
  347. task="extract",
  348. ori_status=const.PROCESSING_STATUS,
  349. new_status=const.INIT_STATUS
  350. )
  351. return
  352. case 500:
  353. # other error
  354. log(
  355. task=const.TASK_NAME,
  356. function="extract_best_frame_with_gemini_ai",
  357. message="other error",
  358. data={
  359. "task_id": task["id"],
  360. "file_name": file_name,
  361. "state": state,
  362. "response": response,
  363. },
  364. )
  365. update_task_queue_status(
  366. db_client=self.db_client,
  367. task_id=task["id"],
  368. task="extract",
  369. ori_status=const.PROCESSING_STATUS,
  370. new_status=const.FAIL_STATUS
  371. )
  372. except Exception as e:
  373. log(
  374. task=const.TASK_NAME,
  375. function="extract_best_frame_with_gemini_ai",
  376. message="task_failed_inside_cycle",
  377. data={
  378. "task_id": task["id"],
  379. "track_back": traceback.format_exc(),
  380. "error": str(e),
  381. },
  382. )
  383. update_task_queue_status(
  384. db_client=self.db_client,
  385. task_id=task["id"],
  386. task="extract",
  387. ori_status=const.PROCESSING_STATUS,
  388. new_status=const.FAIL_STATUS,
  389. )
  390. except Exception as e:
  391. log(
  392. task=const.TASK_NAME,
  393. function="extract_best_frame_with_gemini_ai",
  394. message="task_failed_outside_cycle",
  395. data={
  396. "task_id": task["id"],
  397. "track_back": traceback.format_exc(),
  398. "error": str(e),
  399. },
  400. )
  401. update_task_queue_status(
  402. db_client=self.db_client,
  403. task_id=task["id"],
  404. task="extract",
  405. ori_status=const.PROCESSING_STATUS,
  406. new_status=const.FAIL_STATUS,
  407. )
  408. def extract_best_frame_with_gemini_ai(self):
  409. if not self.api_status:
  410. log(
  411. task=const.TASK_NAME,
  412. function="extract_best_frame_with_gemini_ai",
  413. message="api_status is False, skip this task",
  414. )
  415. return
  416. # roll back lock tasks
  417. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="extract")
  418. log(
  419. task=const.TASK_NAME,
  420. function="extract_best_frame_with_gemini_ai",
  421. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  422. )
  423. # do extract frame task
  424. task_list = self.get_extract_task_list()
  425. for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
  426. self.extract_each_video(task=task)
  427. def get_each_cover(self, task: dict) -> None:
  428. lock_status = self._lock_task(task_id=task["id"], task_name="get_cover")
  429. if not lock_status:
  430. return None
  431. time_str = normalize_time_str(task["best_frame_time_ms"])
  432. if time_str:
  433. response = get_video_cover(
  434. video_oss_path=task["video_oss_path"], time_millisecond_str=time_str
  435. )
  436. log(
  437. task=const.TASK_NAME,
  438. function="extract_cover_with_ffmpeg",
  439. message="get_video_cover_with_ffmpeg",
  440. data={
  441. "task_id": task["id"],
  442. "video_oss_path": task["video_oss_path"],
  443. "time_millisecond_str": time_str,
  444. "response": response,
  445. },
  446. )
  447. if response["success"] and response["data"]:
  448. cover_oss_path = response["data"]
  449. self.set_cover_result(task_id=task["id"], cover_oss_path=cover_oss_path)
  450. else:
  451. update_task_queue_status(
  452. db_client=self.db_client,
  453. task_id=task["id"],
  454. task="get_cover",
  455. ori_status=const.PROCESSING_STATUS,
  456. new_status=const.FAIL_STATUS,
  457. )
  458. else:
  459. log(
  460. task=const.TASK_NAME,
  461. function="extract_cover_with_ffmpeg",
  462. message="time_str format is not correct",
  463. data={
  464. "task_id": task["id"],
  465. "video_oss_path": task["video_oss_path"],
  466. "time_millisecond_str": time_str,
  467. },
  468. )
  469. update_task_queue_status(
  470. db_client=self.db_client,
  471. task_id=task["id"],
  472. task="get_cover",
  473. ori_status=const.PROCESSING_STATUS,
  474. new_status=const.FAIL_STATUS,
  475. )
  476. def get_cover_with_best_frame(self):
  477. """
  478. get cover with best frame
  479. """
  480. if not self.api_status:
  481. log(
  482. task=const.TASK_NAME,
  483. function="extract_cover_with_ffmpeg",
  484. message="api_status is False, skip this task",
  485. )
  486. return
  487. # roll back lock tasks
  488. roll_back_lock_tasks_count = self._roll_back_lock_tasks(task="get_cover")
  489. log(
  490. task=const.TASK_NAME,
  491. function="extract_cover_with_ffmpeg",
  492. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  493. )
  494. # get task list
  495. task_list = self.get_cover_task_list()
  496. for task in tqdm(task_list, desc="extract_cover_with_ffmpeg"):
  497. self.get_each_cover(task=task)