ai_tag_task.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. import shutil
  3. import json
  4. import datetime
  5. import time
  6. import traceback
  7. import requests
  8. from threading import Timer
  9. from utils import data_check, get_feature_data, asr_validity_discrimination
  10. from whisper_asr import get_whisper_asr
  11. from gpt_tag import request_gpt
  12. from config import set_config
  13. from log import Log
  14. config_ = set_config()
  15. log_ = Log()
  16. features = ['videoid', 'title', 'video_path']
  17. def get_video_ai_tags(video_id, video_file, video_info):
  18. try:
  19. st_time = time.time()
  20. log_message = {
  21. 'videoId': int(video_id),
  22. }
  23. title = video_info.get('title')
  24. log_message['videoPath'] = video_info.get('video_path')
  25. log_message['title'] = video_info.get('title')
  26. # 1. asr
  27. asr_res_initial = get_whisper_asr(video=video_file)
  28. log_message['asrRes'] = asr_res_initial
  29. # 2. 判断asr识别的文本是否有效
  30. validity = asr_validity_discrimination(text=asr_res_initial)
  31. log_message['asrValidity'] = validity
  32. if validity is True:
  33. # 3. 对asr结果进行清洗
  34. asr_res = asr_res_initial.replace('\n', '')
  35. for stop_word in config_.STOP_WORDS:
  36. asr_res = asr_res.replace(stop_word, '')
  37. # token限制: 字数 <= 2500
  38. asr_res = asr_res[-2500:]
  39. # 4. gpt产出结果
  40. # 4.1 gpt产出summary, keywords,
  41. prompt1 = f"{config_.GPT_PROMPT['tags']['prompt6']}{asr_res.strip()}"
  42. log_message['gptPromptSummaryKeywords'] = prompt1
  43. gpt_res1 = request_gpt(prompt=prompt1)
  44. log_message['gptResSummaryKeywords'] = gpt_res1
  45. if gpt_res1 is not None:
  46. # 4.2 获取summary, keywords, title进行分类
  47. try:
  48. gpt_res1_json = json.loads(gpt_res1)
  49. summary = gpt_res1_json['summary']
  50. keywords = gpt_res1_json['keywords']
  51. log_message['summary'] = summary
  52. log_message['keywords'] = keywords
  53. prompt2_param = f"标题:{title}\n概况:{summary}\n关键词:{keywords}"
  54. prompt2 = f"{config_.GPT_PROMPT['tags']['prompt7']}{prompt2_param}"
  55. log_message['gptPromptTag'] = prompt2
  56. gpt_res2 = request_gpt(prompt=prompt2)
  57. log_message['gptResTag'] = gpt_res2
  58. if gpt_res2 is not None:
  59. confidence_up_list = []
  60. try:
  61. for item in json.loads(gpt_res2):
  62. if item['confidence'] > 0.5 and item['category'] in config_.TAGS_NEW:
  63. confidence_up_list.append(f"AI标签-{item['category']}")
  64. except:
  65. pass
  66. confidence_up = ','.join(confidence_up_list)
  67. log_message['AITags'] = confidence_up
  68. # 5. 调用后端接口,结果传给后端
  69. if len(confidence_up) > 0:
  70. response = requests.post(url=config_.ADD_VIDEO_AI_TAGS_URL,
  71. json={'videoId': int(video_id), 'tagNames': confidence_up})
  72. res_data = json.loads(response.text)
  73. if res_data['code'] != 0:
  74. log_.error({'videoId': video_id, 'msg': 'add video ai tags fail!'})
  75. except:
  76. pass
  77. else:
  78. pass
  79. log_message['executeTime'] = (time.time() - st_time) * 1000
  80. log_.info(log_message)
  81. except Exception as e:
  82. log_.error(e)
  83. log_.error(traceback.format_exc())
  84. def ai_tags(project, table, dt):
  85. # 获取特征数据
  86. feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
  87. video_id_list = feature_df['videoid'].to_list()
  88. video_info = {}
  89. for video_id in video_id_list:
  90. title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
  91. video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
  92. if title is None:
  93. continue
  94. title = title.strip()
  95. if len(title) > 0:
  96. video_info[video_id] = {'title': title, 'video_path': video_path}
  97. # print(video_id, title)
  98. print(len(video_info))
  99. # 获取已下载视频
  100. download_folder = 'videos'
  101. retry = 0
  102. while retry < 3:
  103. video_folder_list = os.listdir(download_folder)
  104. if len(video_folder_list) < 2:
  105. retry += 1
  106. time.sleep(60)
  107. continue
  108. for video_id in video_folder_list:
  109. if video_id not in video_id_list:
  110. continue
  111. if video_info.get(video_id, None) is None:
  112. shutil.rmtree(os.path.join(download_folder, video_id))
  113. else:
  114. video_folder = os.path.join(download_folder, video_id)
  115. for filename in os.listdir(video_folder):
  116. video_type = filename.split('.')[-1]
  117. if video_type in ['mp4', 'm3u8']:
  118. video_file = os.path.join(video_folder, filename)
  119. get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
  120. # 将处理过的视频进行删除
  121. shutil.rmtree(os.path.join(download_folder, video_id))
  122. else:
  123. shutil.rmtree(os.path.join(download_folder, video_id))
  124. def timer_check():
  125. try:
  126. project = config_.DAILY_VIDEO['project']
  127. table = config_.DAILY_VIDEO['table']
  128. now_date = datetime.datetime.today()
  129. print(f"now_date: {datetime.datetime.strftime(now_date, '%Y%m%d')}")
  130. dt = datetime.datetime.strftime(now_date-datetime.timedelta(days=1), '%Y%m%d')
  131. # 查看数据是否已准备好
  132. data_count = data_check(project=project, table=table, dt=dt)
  133. if data_count > 0:
  134. print(f'videos count = {data_count}')
  135. # 数据准备好,进行视频下载
  136. ai_tags(project=project, table=table, dt=dt)
  137. print(f"videos ai tag finished!")
  138. else:
  139. # 数据没准备好,1分钟后重新检查
  140. Timer(60, timer_check).start()
  141. except Exception as e:
  142. print(f"视频ai打标签失败, exception: {e}, traceback: {traceback.format_exc()}")
  143. if __name__ == '__main__':
  144. timer_check()