# extract_video_best_frame.py (19 KB)
"""
@author luojunhui
@desc find best frame from each video

Cold-start AI pipeline task: upload videos to Gemini AI, ask Gemini for the
best frame of each video, then produce a cover image at that frame's time.
Progress of every stage is tracked through status columns in a MySQL table.
"""
  5. import os
  6. import datetime
  7. import traceback
  8. from tqdm import tqdm
  9. from pymysql.cursors import DictCursor
  10. from applications import log
  11. from applications.api import GoogleAIAPI
  12. from applications.const import GoogleVideoUnderstandTaskConst
  13. from applications.db import DatabaseConnector
  14. from config import long_articles_config
  15. from coldStartTasks.ai_pipeline.basic import download_file
  16. from coldStartTasks.ai_pipeline.basic import update_task_queue_status
  17. from coldStartTasks.ai_pipeline.basic import roll_back_lock_tasks
  18. from coldStartTasks.ai_pipeline.basic import extract_best_frame_prompt
  19. from coldStartTasks.ai_pipeline.basic import get_video_cover
  20. from coldStartTasks.ai_pipeline.basic import normalize_time_str
# Task configuration: table name, status codes, pool size, lock timeout, download dir.
const = GoogleVideoUnderstandTaskConst()
# Shared GoogleAIAPI client (created at import time) used by the upload and extract stages.
google_ai = GoogleAIAPI()
  23. class ExtractVideoBestFrame:
  24. """
  25. extract video best frame from each video by GeminiAI
  26. """
  27. def __init__(self):
  28. self.db_client = DatabaseConnector(db_config=long_articles_config)
  29. self.db_client.connect()
  30. def get_upload_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  31. """
  32. get upload task list
  33. """
  34. fetch_query = f"""
  35. select id, video_oss_path from {const.TABLE_NAME}
  36. where upload_status = {const.INIT_STATUS}
  37. order by priority desc
  38. limit {task_num};
  39. """
  40. upload_task_list = self.db_client.fetch(
  41. query=fetch_query, cursor_type=DictCursor
  42. )
  43. return upload_task_list
  44. def get_extract_task_list(self, task_num: int = const.POOL_SIZE) -> list[dict]:
  45. """
  46. get extract task list
  47. """
  48. fetch_query = f"""
  49. select id, file_name from {const.TABLE_NAME}
  50. where upload_status = {const.SUCCESS_STATUS} and extract_status = {const.INIT_STATUS}
  51. order by file_expire_time
  52. limit {task_num};
  53. """
  54. extract_task_list = self.db_client.fetch(
  55. query=fetch_query, cursor_type=DictCursor
  56. )
  57. return extract_task_list
  58. def get_cover_task_list(self) -> list[dict]:
  59. """
  60. get cover task list
  61. """
  62. fetch_query = f"""
  63. select id, video_oss_path, best_frame_time_ms from {const.TABLE_NAME}
  64. where extract_status = {const.SUCCESS_STATUS} and get_cover_status = {const.INIT_STATUS};
  65. """
  66. extract_task_list = self.db_client.fetch(
  67. query=fetch_query, cursor_type=DictCursor
  68. )
  69. return extract_task_list
  70. def get_processing_task_pool_size(self) -> int:
  71. """
  72. get processing task pool size
  73. """
  74. fetch_query = f"""
  75. select count(1) as pool_size from {const.TABLE_NAME}
  76. where upload_status = {const.SUCCESS_STATUS} and file_state = 'PROCESSING' and extract_status = {const.INIT_STATUS};
  77. """
  78. fetch_response = self.db_client.fetch(query=fetch_query, cursor_type=DictCursor)
  79. processing_task_pool_size = (
  80. fetch_response[0]["pool_size"] if fetch_response else 0
  81. )
  82. return processing_task_pool_size
  83. def set_upload_result(
  84. self, task_id: int, file_name: str, file_state: str, file_expire_time: str
  85. ) -> int:
  86. update_query = f"""
  87. update {const.TABLE_NAME}
  88. set upload_status = %s, upload_status_ts = %s,
  89. file_name = %s, file_state = %s, file_expire_time = %s
  90. where id = %s and upload_status = %s;
  91. """
  92. update_rows = self.db_client.save(
  93. query=update_query,
  94. params=(
  95. const.SUCCESS_STATUS,
  96. datetime.datetime.now(),
  97. file_name,
  98. file_state,
  99. file_expire_time,
  100. task_id,
  101. const.PROCESSING_STATUS,
  102. ),
  103. )
  104. return update_rows
  105. def set_extract_result(
  106. self, task_id: int, file_state: str, best_frame_tims_ms: str
  107. ) -> int:
  108. update_query = f"""
  109. update {const.TABLE_NAME}
  110. set extract_status = %s, extract_status_ts = %s,
  111. file_state = %s, best_frame_time_ms = %s
  112. where id = %s and extract_status = %s;
  113. """
  114. update_rows = self.db_client.save(
  115. query=update_query,
  116. params=(
  117. const.SUCCESS_STATUS,
  118. datetime.datetime.now(),
  119. file_state,
  120. best_frame_tims_ms,
  121. task_id,
  122. const.PROCESSING_STATUS,
  123. ),
  124. )
  125. return update_rows
  126. def set_cover_result(self, task_id: int, cover_oss_path: str) -> int:
  127. update_query = f"""
  128. update {const.TABLE_NAME}
  129. set cover_oss_path = %s, get_cover_status = %s, get_cover_status_ts = %s
  130. where id = %s and get_cover_status = %s;
  131. """
  132. update_rows = self.db_client.save(
  133. query=update_query,
  134. params=(
  135. cover_oss_path,
  136. const.SUCCESS_STATUS,
  137. datetime.datetime.now(),
  138. task_id,
  139. const.PROCESSING_STATUS,
  140. )
  141. )
  142. return update_rows
  143. def upload_video_to_gemini_ai(
  144. self, max_processing_pool_size: int = const.POOL_SIZE
  145. ) -> None:
  146. # upload video to gemini ai
  147. roll_back_lock_tasks_count = roll_back_lock_tasks(
  148. db_client=self.db_client,
  149. task="upload",
  150. init_status=const.INIT_STATUS,
  151. processing_status=const.PROCESSING_STATUS,
  152. max_process_time=const.MAX_PROCESSING_TIME,
  153. )
  154. log(
  155. task=const.TASK_NAME,
  156. function="upload_video_to_gemini_ai",
  157. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  158. )
  159. processing_task_num = self.get_processing_task_pool_size()
  160. res_task_num = max_processing_pool_size - processing_task_num
  161. if res_task_num:
  162. upload_task_list = self.get_upload_task_list(task_num=res_task_num)
  163. for task in tqdm(upload_task_list, desc="upload_video_to_gemini_ai"):
  164. lock_status = update_task_queue_status(
  165. db_client=self.db_client,
  166. task_id=task["id"],
  167. task="upload",
  168. ori_status=const.INIT_STATUS,
  169. new_status=const.PROCESSING_STATUS,
  170. )
  171. if not lock_status:
  172. continue
  173. try:
  174. file_path = download_file(task["id"], task["video_oss_path"])
  175. upload_response = google_ai.upload_file(file_path)
  176. if upload_response:
  177. file_name, file_state, expire_time = upload_response
  178. self.set_upload_result(
  179. task_id=task["id"],
  180. file_name=file_name,
  181. file_state=file_state,
  182. file_expire_time=expire_time,
  183. )
  184. else:
  185. # set status as fail
  186. update_task_queue_status(
  187. db_client=self.db_client,
  188. task_id=task["id"],
  189. task="upload",
  190. ori_status=const.PROCESSING_STATUS,
  191. new_status=const.FAIL_STATUS,
  192. )
  193. except Exception as e:
  194. log(
  195. task=const.TASK_NAME,
  196. function="upload_video_to_gemini_ai",
  197. message="task_failed",
  198. data={
  199. "task_id": task["id"],
  200. "track_back": traceback.format_exc(),
  201. "error": str(e),
  202. }
  203. )
  204. update_task_queue_status(
  205. db_client=self.db_client,
  206. task_id=task["id"],
  207. task="upload",
  208. ori_status=const.PROCESSING_STATUS,
  209. new_status=const.FAIL_STATUS,
  210. )
  211. continue
  212. else:
  213. log(
  214. task=const.TASK_NAME,
  215. function="upload_video_to_gemini_ai",
  216. message="reach pool size, no more space for task to upload",
  217. )
  218. def extract_best_frame_with_gemini_ai(self):
  219. # roll back lock tasks
  220. roll_back_lock_tasks_count = roll_back_lock_tasks(
  221. db_client=self.db_client,
  222. task="extract",
  223. init_status=const.INIT_STATUS,
  224. processing_status=const.PROCESSING_STATUS,
  225. max_process_time=const.MAX_PROCESSING_TIME,
  226. )
  227. log(
  228. task=const.TASK_NAME,
  229. function="extract_best_frame_with_gemini_ai",
  230. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  231. )
  232. # do extract frame task
  233. task_list = self.get_extract_task_list()
  234. for task in tqdm(task_list, desc="extract_best_frame_with_gemini_ai"):
  235. # lock task
  236. lock_status = update_task_queue_status(
  237. db_client=self.db_client,
  238. task_id=task["id"],
  239. task="extract",
  240. ori_status=const.INIT_STATUS,
  241. new_status=const.PROCESSING_STATUS,
  242. )
  243. if not lock_status:
  244. continue
  245. file_name = task["file_name"]
  246. video_local_path = os.path.join(const.DIR_NAME, "{}.mp4".format(task["id"]))
  247. try:
  248. google_file = google_ai.get_google_file(file_name)
  249. state = google_file.state.name
  250. match state:
  251. case "PROCESSING":
  252. # google is still processing this video
  253. update_task_queue_status(
  254. db_client=self.db_client,
  255. task_id=task["id"],
  256. task="extract",
  257. ori_status=const.PROCESSING_STATUS,
  258. new_status=const.INIT_STATUS,
  259. )
  260. log(
  261. task=const.TASK_NAME,
  262. function="extract_best_frame_with_gemini_ai",
  263. message="google is still processing this video",
  264. data={
  265. "task_id": task["id"],
  266. "file_name": file_name,
  267. "state": state
  268. }
  269. )
  270. case "FAILED":
  271. # google process this video failed
  272. update_query = f"""
  273. update {const.TABLE_NAME}
  274. set file_state = %s, extract_status = %s, extract_status_ts = %s
  275. where id = %s and extract_status = %s;
  276. """
  277. self.db_client.save(
  278. query=update_query,
  279. params=(
  280. "FAILED",
  281. const.FAIL_STATUS,
  282. datetime.datetime.now(),
  283. task["id"],
  284. const.PROCESSING_STATUS,
  285. ),
  286. )
  287. log(
  288. task=const.TASK_NAME,
  289. function="extract_best_frame_with_gemini_ai",
  290. message="google process this video failed",
  291. data={
  292. "task_id": task["id"],
  293. "file_name": file_name,
  294. "state": state
  295. }
  296. )
  297. case "ACTIVE":
  298. # video process successfully
  299. try:
  300. best_frame_tims_ms = google_ai.fetch_info_from_google_ai(
  301. prompt=extract_best_frame_prompt(),
  302. video_file=google_file,
  303. )
  304. if best_frame_tims_ms:
  305. self.set_extract_result(
  306. task_id=task["id"],
  307. file_state="ACTIVE",
  308. best_frame_tims_ms=best_frame_tims_ms.strip(),
  309. )
  310. else:
  311. update_task_queue_status(
  312. db_client=self.db_client,
  313. task_id=task["id"],
  314. task="extract",
  315. ori_status=const.PROCESSING_STATUS,
  316. new_status=const.FAIL_STATUS,
  317. )
  318. # delete local file and google file
  319. if os.path.exists(video_local_path):
  320. os.remove(video_local_path)
  321. google_ai.delete_video(file_name)
  322. log(
  323. task=const.TASK_NAME,
  324. function="extract_best_frame_with_gemini_ai",
  325. message="video process successfully",
  326. data={
  327. "task_id": task["id"],
  328. "file_name": file_name,
  329. "state": state,
  330. "best_frame_tims_ms": best_frame_tims_ms
  331. }
  332. )
  333. except Exception as e:
  334. log(
  335. task=const.TASK_NAME,
  336. function="extract_best_frame_with_gemini_ai",
  337. message="task_failed_inside_cycle",
  338. data={
  339. "task_id": task["id"],
  340. "track_back": traceback.format_exc(),
  341. "error": str(e),
  342. }
  343. )
  344. update_task_queue_status(
  345. db_client=self.db_client,
  346. task_id=task["id"],
  347. task="extract",
  348. ori_status=const.PROCESSING_STATUS,
  349. new_status=const.FAIL_STATUS,
  350. )
  351. except Exception as e:
  352. log(
  353. task=const.TASK_NAME,
  354. function="extract_best_frame_with_gemini_ai",
  355. message="task_failed_outside_cycle",
  356. data={
  357. "task_id": task["id"],
  358. "track_back": traceback.format_exc(),
  359. "error": str(e),
  360. }
  361. )
  362. update_task_queue_status(
  363. db_client=self.db_client,
  364. task_id=task["id"],
  365. task="extract",
  366. ori_status=const.PROCESSING_STATUS,
  367. new_status=const.FAIL_STATUS,
  368. )
  369. def get_cover_with_best_frame(self):
  370. """
  371. get cover with best frame
  372. """
  373. # roll back lock tasks
  374. roll_back_lock_tasks_count = roll_back_lock_tasks(
  375. db_client=self.db_client,
  376. task="get_cover",
  377. init_status=const.INIT_STATUS,
  378. processing_status=const.PROCESSING_STATUS,
  379. max_process_time=const.MAX_PROCESSING_TIME,
  380. )
  381. log(
  382. task=const.TASK_NAME,
  383. function="extract_cover_with_ffmpeg",
  384. message=f"roll_back_lock_tasks_count: {roll_back_lock_tasks_count}",
  385. )
  386. # get task list
  387. task_list = self.get_cover_task_list()
  388. for task in tqdm(task_list, desc="extract_cover_with_ffmpeg"):
  389. # lock task
  390. lock_status = update_task_queue_status(
  391. db_client=self.db_client,
  392. task_id=task["id"],
  393. task="get_cover",
  394. ori_status=const.INIT_STATUS,
  395. new_status=const.PROCESSING_STATUS,
  396. )
  397. if not lock_status:
  398. continue
  399. time_str = normalize_time_str(task["best_frame_time_ms"])
  400. if time_str:
  401. response = get_video_cover(
  402. video_oss_path=task["video_oss_path"], time_millisecond_str=time_str
  403. )
  404. log(
  405. task=const.TASK_NAME,
  406. function="extract_cover_with_ffmpeg",
  407. message="get_video_cover_with_ffmpeg",
  408. data={
  409. "task_id": task["id"],
  410. "video_oss_path": task["video_oss_path"],
  411. "time_millisecond_str": time_str,
  412. "response": response
  413. }
  414. )
  415. if response["success"] and response["data"]:
  416. cover_oss_path = response["data"]
  417. self.set_cover_result(task_id=task["id"], cover_oss_path=cover_oss_path)
  418. else:
  419. update_task_queue_status(
  420. db_client=self.db_client,
  421. task_id=task["id"],
  422. task="get_cover",
  423. ori_status=const.PROCESSING_STATUS,
  424. new_status=const.FAIL_STATUS,
  425. )
  426. else:
  427. log(
  428. task=const.TASK_NAME,
  429. function="extract_cover_with_ffmpeg",
  430. message="time_str format is not correct",
  431. data={
  432. "task_id": task["id"],
  433. "video_oss_path": task["video_oss_path"],
  434. "time_millisecond_str": time_str
  435. }
  436. )
  437. update_task_queue_status(
  438. db_client=self.db_client,
  439. task_id=task["id"],
  440. task="get_cover",
  441. ori_status=const.PROCESSING_STATUS,
  442. new_status=const.FAIL_STATUS,
  443. )