ai_tag_task.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import os
  2. import shutil
  3. import json
  4. import datetime
  5. import time
  6. import traceback
  7. import requests
  8. import multiprocessing
  9. from threading import Timer
  10. from utils import data_check, get_feature_data, asr_validity_discrimination
  11. from whisper_asr import get_whisper_asr
  12. from gpt_tag import request_gpt
  13. from config import set_config
  14. from log import Log
  15. config_ = set_config()
  16. log_ = Log()
  17. features = ['videoid', 'title', 'video_path']
  18. def get_video_ai_tags(video_id, asr_file, video_info):
  19. try:
  20. st_time = time.time()
  21. log_message = {
  22. 'videoId': int(video_id),
  23. }
  24. title = video_info.get('title')
  25. log_message['videoPath'] = video_info.get('video_path')
  26. log_message['title'] = video_info.get('title')
  27. # 1. 获取asr结果
  28. # asr_res_initial = get_whisper_asr(video=video_file)
  29. with open(asr_file, 'r', encoding='utf-8') as rf:
  30. asr_res_initial = rf.read()
  31. log_message['asrRes'] = asr_res_initial
  32. # 2. 判断asr识别的文本是否有效
  33. validity = asr_validity_discrimination(text=asr_res_initial)
  34. log_message['asrValidity'] = validity
  35. if validity is True:
  36. # 3. 对asr结果进行清洗
  37. asr_res = asr_res_initial.replace('\n', '')
  38. for stop_word in config_.STOP_WORDS:
  39. asr_res = asr_res.replace(stop_word, '')
  40. # token限制: 字数 <= 2500
  41. asr_res = asr_res[-2500:]
  42. # 4. gpt产出结果
  43. # 4.1 gpt产出summary, keywords,
  44. prompt1 = f"{config_.GPT_PROMPT['tags']['prompt6']}{asr_res.strip()}"
  45. log_message['gptPromptSummaryKeywords'] = prompt1
  46. gpt_res1 = request_gpt(prompt=prompt1)
  47. log_message['gptResSummaryKeywords'] = gpt_res1
  48. if gpt_res1 is not None:
  49. # 4.2 获取summary, keywords, title进行分类
  50. try:
  51. gpt_res1_json = json.loads(gpt_res1)
  52. summary = gpt_res1_json['summary']
  53. keywords = gpt_res1_json['keywords']
  54. log_message['summary'] = summary
  55. log_message['keywords'] = str(keywords)
  56. prompt2_param = f"标题:{title}\n概况:{summary}\n关键词:{keywords}"
  57. prompt2 = f"{config_.GPT_PROMPT['tags']['prompt7']}{prompt2_param}"
  58. log_message['gptPromptTag'] = prompt2
  59. gpt_res2 = request_gpt(prompt=prompt2)
  60. log_message['gptResTag'] = gpt_res2
  61. if gpt_res2 is not None:
  62. confidence_up_list = []
  63. try:
  64. for item in json.loads(gpt_res2):
  65. if item['confidence'] > 0.5 and item['category'] in config_.TAGS_NEW:
  66. confidence_up_list.append(f"AI标签-{item['category']}")
  67. except:
  68. pass
  69. confidence_up = ','.join(confidence_up_list)
  70. log_message['AITags'] = confidence_up
  71. # 5. 调用后端接口,结果传给后端
  72. if len(confidence_up) > 0:
  73. response = requests.post(url=config_.ADD_VIDEO_AI_TAGS_URL,
  74. json={'videoId': int(video_id), 'tagNames': confidence_up})
  75. res_data = json.loads(response.text)
  76. if res_data['code'] != 0:
  77. log_.error({'videoId': video_id, 'msg': 'add video ai tags fail!'})
  78. except:
  79. pass
  80. else:
  81. pass
  82. log_message['executeTime'] = (time.time() - st_time) * 1000
  83. log_.info(log_message)
  84. except Exception as e:
  85. log_.error(e)
  86. log_.error(traceback.format_exc())
  87. def process(video_id, video_info, download_folder):
  88. if video_info.get(video_id, None) is None:
  89. shutil.rmtree(os.path.join(download_folder, video_id))
  90. else:
  91. video_folder = os.path.join(download_folder, video_id)
  92. for filename in os.listdir(video_folder):
  93. video_type = filename.split('.')[-1]
  94. if video_type in ['mp4', 'm3u8']:
  95. video_file = os.path.join(video_folder, filename)
  96. get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
  97. # 将处理过的视频进行删除
  98. shutil.rmtree(os.path.join(download_folder, video_id))
  99. else:
  100. shutil.rmtree(os.path.join(download_folder, video_id))
  101. def ai_tags(project, table, dt):
  102. # 获取特征数据
  103. feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
  104. video_id_list = feature_df['videoid'].to_list()
  105. video_info = {}
  106. for video_id in video_id_list:
  107. title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
  108. video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
  109. if title is None:
  110. continue
  111. title = title.strip()
  112. if len(title) > 0:
  113. video_info[video_id] = {'title': title, 'video_path': video_path}
  114. # print(video_id, title)
  115. print(len(video_info))
  116. # 获取已下载视频
  117. download_folder = 'videos'
  118. retry = 0
  119. while retry < 3:
  120. video_folder_list = os.listdir(download_folder)
  121. if len(video_folder_list) < 2:
  122. retry += 1
  123. time.sleep(60)
  124. continue
  125. # pool = multiprocessing.Pool(processes=5)
  126. # for video_id in video_folder_list:
  127. # if video_id not in video_id_list:
  128. # continue
  129. # pool.apply_async(
  130. # func=process,
  131. # args=(video_id, video_info, download_folder)
  132. # )
  133. # pool.close()
  134. # pool.join()
  135. for video_id in video_folder_list:
  136. if video_id not in video_id_list:
  137. continue
  138. if video_info.get(video_id, None) is None:
  139. shutil.rmtree(os.path.join(download_folder, video_id))
  140. else:
  141. video_folder = os.path.join(download_folder, video_id)
  142. for filename in os.listdir(video_folder):
  143. video_type = filename.split('.')[-1]
  144. if video_type in ['mp4', 'm3u8']:
  145. video_file = os.path.join(video_folder, filename)
  146. get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
  147. # 将处理过的视频进行删除
  148. shutil.rmtree(os.path.join(download_folder, video_id))
  149. else:
  150. shutil.rmtree(os.path.join(download_folder, video_id))
  151. def ai_tags_new(project, table, dt):
  152. # 获取特征数据
  153. feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
  154. video_id_list = feature_df['videoid'].to_list()
  155. video_info = {}
  156. for video_id in video_id_list:
  157. title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
  158. video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
  159. if title is None:
  160. continue
  161. title = title.strip()
  162. if len(title) > 0:
  163. video_info[video_id] = {'title': title, 'video_path': video_path}
  164. # print(video_id, title)
  165. print(len(video_info))
  166. # 获取已asr识别的视频
  167. asr_folder = 'asr_res'
  168. retry = 0
  169. while retry < 30:
  170. asr_file_list = os.listdir(asr_folder)
  171. if len(asr_file_list) < 1:
  172. retry += 1
  173. time.sleep(60)
  174. continue
  175. retry = 0
  176. for asr_filename in asr_file_list:
  177. video_id = asr_filename[:-4]
  178. if video_id not in video_id_list:
  179. continue
  180. asr_file = os.path.join(asr_folder, asr_filename)
  181. if video_info.get(video_id, None) is None:
  182. os.remove(asr_file)
  183. else:
  184. get_video_ai_tags(video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
  185. os.remove(asr_file)
  186. def timer_check():
  187. try:
  188. project = config_.DAILY_VIDEO['project']
  189. table = config_.DAILY_VIDEO['table']
  190. now_date = datetime.datetime.today()
  191. print(f"now_date: {datetime.datetime.strftime(now_date, '%Y%m%d')}")
  192. dt = datetime.datetime.strftime(now_date-datetime.timedelta(days=1), '%Y%m%d')
  193. # 查看数据是否已准备好
  194. data_count = data_check(project=project, table=table, dt=dt)
  195. if data_count > 0:
  196. print(f'videos count = {data_count}')
  197. asr_folder = 'asr_res'
  198. if not os.path.exists(asr_folder):
  199. # 1分钟后重新检查
  200. Timer(60, timer_check).start()
  201. else:
  202. # 数据准备好,进行aiTag
  203. ai_tags_new(project=project, table=table, dt=dt)
  204. print(f"videos ai tag finished!")
  205. else:
  206. # 数据没准备好,1分钟后重新检查
  207. Timer(60, timer_check).start()
  208. except Exception as e:
  209. print(f"视频ai打标签失败, exception: {e}, traceback: {traceback.format_exc()}")
  210. if __name__ == '__main__':
  211. timer_check()