ai_tag_task.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import os
  2. import shutil
  3. import json
  4. import datetime
  5. import time
  6. import traceback
  7. import requests
  8. import multiprocessing
  9. import ODPSQueryUtil
  10. from threading import Timer
  11. from utils import data_check, get_feature_data, asr_validity_discrimination
  12. from whisper_asr import get_whisper_asr
  13. from gpt_tag import request_gpt
  14. from config import set_config
  15. from log import Log
  16. from ReadXlsxFile import getVideoInfoInXlxs
  17. import mysql_connect
  18. config_ = set_config()
  19. log_ = Log()
  20. features = ['videoid', 'title', 'video_path']
  21. def get_video_ai_tags(video_id, asr_file, video_info):
  22. try:
  23. st_time = time.time()
  24. log_message = {
  25. 'videoId': int(video_id),
  26. }
  27. title = video_info.get('title')
  28. log_message['videoPath'] = video_info.get('video_path')
  29. log_message['title'] = video_info.get('title')
  30. # 1. 获取asr结果
  31. # asr_res_initial = get_whisper_asr(video=video_file)
  32. with open(asr_file, 'r', encoding='utf-8') as rf:
  33. asr_res_initial = rf.read()
  34. log_message['asrRes'] = asr_res_initial
  35. # 2. 判断asr识别的文本是否有效
  36. validity = asr_validity_discrimination(text=asr_res_initial)
  37. log_message['asrValidity'] = validity
  38. if validity is True:
  39. # 3. 对asr结果进行清洗
  40. asr_res = asr_res_initial.replace('\n', '')
  41. for stop_word in config_.STOP_WORDS:
  42. asr_res = asr_res.replace(stop_word, '')
  43. # token限制: 字数 <= 2500
  44. asr_res = asr_res[-2500:]
  45. # 4. gpt产出结果
  46. # 4.1 gpt产出summary, keywords,
  47. prompt1 = f"{config_.GPT_PROMPT['tags']['prompt6']}{asr_res.strip()}"
  48. log_message['gptPromptSummaryKeywords'] = prompt1
  49. gpt_res1 = request_gpt(prompt=prompt1)
  50. log_message['gptResSummaryKeywords'] = gpt_res1
  51. if gpt_res1 is not None:
  52. # 4.2 获取summary, keywords, title进行分类
  53. try:
  54. gpt_res1_json = json.loads(gpt_res1)
  55. summary = gpt_res1_json['summary']
  56. keywords = gpt_res1_json['keywords']
  57. log_message['summary'] = summary
  58. log_message['keywords'] = str(keywords)
  59. # TODO 三个 prompt 拆分成三个请求
  60. prompt2_param = f"标题:{title}\n概况:{summary}\n关键词:{keywords}"
  61. prompt2 = f"{config_.GPT_PROMPT['tags']['prompt8']}{prompt2_param}"
  62. log_message['gptPrompt2'] = prompt2
  63. gpt_res2 = request_gpt(prompt=prompt2)
  64. log_message['gptRes2'] = gpt_res2
  65. prompt3 = f"{config_.GPT_PROMPT['tags']['prompt9']}{prompt2_param}"
  66. log_message['gptPrompt3'] = prompt3
  67. gpt_res3 = request_gpt(prompt=prompt3)
  68. log_message['gptRes3'] = gpt_res3
  69. prompt4 = f"{config_.GPT_PROMPT['tags']['prompt10']}{prompt2_param}"
  70. log_message['gptPrompt4'] = prompt4
  71. gpt_res4 = request_gpt(prompt=prompt4)
  72. log_message['gptRes4'] = gpt_res4
  73. # 5. 解析gpt产出结果
  74. parseRes = praseGptRes(gpt_res2, gpt_res3, gpt_res4)
  75. parseRes['video_id'] = video_id
  76. log_message.update(parseRes)
  77. # 6. 保存结果
  78. mysql_connect.insert_content(parseRes)
  79. except:
  80. log_.error(traceback.format_exc())
  81. pass
  82. else:
  83. pass
  84. log_message['executeTime'] = (time.time() - st_time) * 1000
  85. log_.info(log_message)
  86. except Exception as e:
  87. log_.error(e)
  88. log_.error(traceback.format_exc())
  89. def praseGptRes(gpt_res2, gpt_res3, gpt_res4):
  90. result = {}
  91. if gpt_res2 is not None:
  92. try:
  93. res2 = json.loads(gpt_res2)
  94. result['key_words'] = res2['key_words']
  95. result['search_keys'] = res2['search_keys']
  96. result['extra_keys'] = res2['extra_keys']
  97. except:
  98. pass
  99. if gpt_res3 is not None:
  100. try:
  101. res3 = json.loads(gpt_res3)
  102. result['tone'] = res3['tone']
  103. result['target_audience'] = res3['target_audience']
  104. result['target_age'] = res3['target_age']
  105. except:
  106. pass
  107. if gpt_res4 is not None:
  108. try:
  109. res4 = json.loads(gpt_res4)
  110. result['category'] = res4['category']
  111. result['target_gender'] = res4['target_gender']
  112. result['address'] = res4['address']
  113. result['theme'] = res4['theme']
  114. except:
  115. pass
  116. return result
  117. def process(video_id, video_info, download_folder):
  118. if video_info.get(video_id, None) is None:
  119. shutil.rmtree(os.path.join(download_folder, video_id))
  120. else:
  121. video_folder = os.path.join(download_folder, video_id)
  122. for filename in os.listdir(video_folder):
  123. video_type = filename.split('.')[-1]
  124. if video_type in ['mp4', 'm3u8']:
  125. video_file = os.path.join(video_folder, filename)
  126. get_video_ai_tags(
  127. video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
  128. # 将处理过的视频进行删除
  129. shutil.rmtree(os.path.join(download_folder, video_id))
  130. else:
  131. shutil.rmtree(os.path.join(download_folder, video_id))
  132. def ai_tags(project, table, dt):
  133. # 获取特征数据
  134. feature_df = get_feature_data(
  135. project=project, table=table, dt=dt, features=features)
  136. video_id_list = feature_df['videoid'].to_list()
  137. video_info = {}
  138. for video_id in video_id_list:
  139. title = feature_df[feature_df['videoid']
  140. == video_id]['title'].values[0]
  141. video_path = feature_df[feature_df['videoid']
  142. == video_id]['video_path'].values[0]
  143. if title is None:
  144. continue
  145. title = title.strip()
  146. if len(title) > 0:
  147. video_info[video_id] = {'title': title, 'video_path': video_path}
  148. # print(video_id, title)
  149. print(len(video_info))
  150. # 获取已下载视频
  151. download_folder = 'videos'
  152. retry = 0
  153. while retry < 3:
  154. video_folder_list = os.listdir(download_folder)
  155. if len(video_folder_list) < 2:
  156. retry += 1
  157. time.sleep(60)
  158. continue
  159. # pool = multiprocessing.Pool(processes=5)
  160. # for video_id in video_folder_list:
  161. # if video_id not in video_id_list:
  162. # continue
  163. # pool.apply_async(
  164. # func=process,
  165. # args=(video_id, video_info, download_folder)
  166. # )
  167. # pool.close()
  168. # pool.join()
  169. for video_id in video_folder_list:
  170. if video_id not in video_id_list:
  171. continue
  172. if video_info.get(video_id, None) is None:
  173. shutil.rmtree(os.path.join(download_folder, video_id))
  174. else:
  175. video_folder = os.path.join(download_folder, video_id)
  176. for filename in os.listdir(video_folder):
  177. video_type = filename.split('.')[-1]
  178. if video_type in ['mp4', 'm3u8']:
  179. video_file = os.path.join(video_folder, filename)
  180. get_video_ai_tags(
  181. video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
  182. # 将处理过的视频进行删除
  183. shutil.rmtree(os.path.join(download_folder, video_id))
  184. else:
  185. shutil.rmtree(os.path.join(download_folder, video_id))
  186. def ai_tags_new(project, table, dt):
  187. # 获取特征数据
  188. feature_df = get_feature_data(
  189. project=project, table=table, dt=dt, features=features)
  190. video_id_list = feature_df['videoid'].to_list()
  191. video_info = {}
  192. for video_id in video_id_list:
  193. title = feature_df[feature_df['videoid']
  194. == video_id]['title'].values[0]
  195. video_path = feature_df[feature_df['videoid']
  196. == video_id]['video_path'].values[0]
  197. if title is None:
  198. continue
  199. title = title.strip()
  200. if len(title) > 0:
  201. video_info[video_id] = {'title': title, 'video_path': video_path}
  202. # print(video_id, title)
  203. print(len(video_info))
  204. # 获取已asr识别的视频
  205. asr_folder = 'asr_res'
  206. retry = 0
  207. while retry < 30:
  208. asr_file_list = os.listdir(asr_folder)
  209. if len(asr_file_list) < 1:
  210. retry += 1
  211. time.sleep(60)
  212. continue
  213. retry = 0
  214. for asr_filename in asr_file_list:
  215. video_id = asr_filename[:-4]
  216. if video_id not in video_id_list:
  217. continue
  218. asr_file = os.path.join(asr_folder, asr_filename)
  219. if video_info.get(video_id, None) is None:
  220. os.remove(asr_file)
  221. else:
  222. get_video_ai_tags(
  223. video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
  224. os.remove(asr_file)
  225. def timer_check():
  226. try:
  227. project = config_.DAILY_VIDEO['project']
  228. table = config_.DAILY_VIDEO['table']
  229. now_date = datetime.datetime.today()
  230. print(f"now_date: {datetime.datetime.strftime(now_date, '%Y%m%d')}")
  231. dt = datetime.datetime.strftime(
  232. now_date-datetime.timedelta(days=1), '%Y%m%d')
  233. # 查看数据是否已准备好
  234. data_count = data_check(project=project, table=table, dt=dt)
  235. if data_count > 0:
  236. print(f'videos count = {data_count}')
  237. asr_folder = 'asr_res'
  238. if not os.path.exists(asr_folder):
  239. # 1分钟后重新检查
  240. Timer(60, timer_check).start()
  241. else:
  242. # 数据准备好,进行aiTag
  243. ai_tags_new(project=project, table=table, dt=dt)
  244. print(f"videos ai tag finished!")
  245. else:
  246. # 数据没准备好,1分钟后重新检查
  247. Timer(60, timer_check).start()
  248. except Exception as e:
  249. print(
  250. f"视频ai打标签失败, exception: {e}, traceback: {traceback.format_exc()}")
  251. if __name__ == '__main__':
  252. # timer_check()
  253. feature_df = getVideoInfoInXlxs('past_videos.xlsx')
  254. video_id_list = feature_df['videoid'].to_list()
  255. video_info = {}
  256. for video_id in video_id_list:
  257. titleObj = feature_df[feature_df['videoid']
  258. == video_id]['title'].values[0]
  259. video_path = feature_df[feature_df['videoid']
  260. == video_id]['video_path'].values[0]
  261. title = str(titleObj)
  262. if title is None:
  263. continue
  264. title = title.strip()
  265. if len(title) > 0:
  266. video_info[video_id] = {'title': title, 'video_path': video_path}
  267. # print(video_id, title)
  268. print(len(video_info))
  269. # 获取已asr识别的视频
  270. asr_folder = 'asr_res'
  271. retry = 0
  272. while retry < 30:
  273. asr_file_list = os.listdir(asr_folder)
  274. if len(asr_file_list) < 1:
  275. retry += 1
  276. time.sleep(60)
  277. continue
  278. retry = 0
  279. for asr_filename in asr_file_list:
  280. video_id = asr_filename[:-4]
  281. if video_id not in video_id_list:
  282. continue
  283. asr_file = os.path.join(asr_folder, asr_filename)
  284. if video_info.get(video_id, None) is None:
  285. os.remove(asr_file)
  286. else:
  287. get_video_ai_tags(
  288. video_id=video_id, asr_file=asr_file, video_info=video_info.get(video_id))
  289. os.remove(asr_file)