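"""Video AI-tagging pipeline.

Reads ASR transcripts for videos, asks GPT (via gpt_tag.request_gpt) for a
summary, keywords and several groups of classification tags, parses the JSON
responses, and persists the results through mysql_connect. Candidate videos
come from an ODPS feature table; transcripts are picked up from a local
`asr_res` folder that the polling loops below wait on, and `timer_check`
re-schedules itself with threading.Timer until the daily partition is ready.
"""
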
import os
import shutil
import json
import datetime
import time
import traceback
import requests
import multiprocessing
import ODPSQueryUtil
from threading import Timer
from utils import data_check, get_feature_data, asr_validity_discrimination
from whisper_asr import get_whisper_asr
from gpt_tag import request_gpt
from config import set_config
from log import Log
import mysql_connect

config_ = set_config()
log_ = Log()

features = ['videoid', 'title', 'video_path']


def get_video_ai_tags(video_id, asr_file, video_info):
    """Generate AI tags for one video from its ASR transcript and log the result.

    `video_info` is expected to carry at least 'title' and 'video_path'.
    """
    try:
        st_time = time.time()
        log_message = {
            'videoId': int(video_id),
        }
        title = video_info.get('title')
        log_message['videoPath'] = video_info.get('video_path')
        log_message['title'] = video_info.get('title')
        # 1. Load the ASR result
        # asr_res_initial = get_whisper_asr(video=video_file)
        with open(asr_file, 'r', encoding='utf-8') as rf:
            asr_res_initial = rf.read()
        log_message['asrRes'] = asr_res_initial
        # 2. Check whether the ASR text is valid
        validity = asr_validity_discrimination(text=asr_res_initial)
        log_message['asrValidity'] = validity
        if validity is True:
            # 3. Clean the ASR text
            asr_res = asr_res_initial.replace('\n', '')
            for stop_word in config_.STOP_WORDS:
                asr_res = asr_res.replace(stop_word, '')
            # Token limit: keep at most the last 2500 characters
            asr_res = asr_res[-2500:]
            # 4. Generate results with GPT
            # 4.1 Summary and keywords
            prompt1 = f"{config_.GPT_PROMPT['tags']['prompt6']}{asr_res.strip()}"
            log_message['gptPromptSummaryKeywords'] = prompt1
            gpt_res1 = request_gpt(prompt=prompt1)
            log_message['gptResSummaryKeywords'] = gpt_res1
            if gpt_res1 is not None:
                # 4.2 Classify using summary, keywords and title
                try:
                    gpt_res1_json = json.loads(gpt_res1)
                    summary = gpt_res1_json['summary']
                    keywords = gpt_res1_json['keywords']
                    log_message['summary'] = summary
                    log_message['keywords'] = str(keywords)
                    # TODO: split the three prompts into three separate requests
                    prompt2_param = f"标题:{title}\n概况:{summary}\n关键词:{keywords}"
                    prompt2 = f"{config_.GPT_PROMPT['tags']['prompt8']}{prompt2_param}"
                    log_message['gptPrompt2'] = prompt2
                    gpt_res2 = request_gpt(prompt=prompt2)
                    log_message['gptRes2'] = gpt_res2
                    prompt3 = f"{config_.GPT_PROMPT['tags']['prompt9']}{prompt2_param}"
                    log_message['gptPrompt3'] = prompt3
                    gpt_res3 = request_gpt(prompt=prompt3)
                    log_message['gptRes3'] = gpt_res3
                    prompt4 = f"{config_.GPT_PROMPT['tags']['prompt10']}{prompt2_param}"
                    log_message['gptPrompt4'] = prompt4
                    gpt_res4 = request_gpt(prompt=prompt4)
                    log_message['gptRes4'] = gpt_res4
                    # 5. Parse the GPT results
                    parseRes = praseGptRes(gpt_res2, gpt_res3, gpt_res4)
                    log_message.update(parseRes)
                    # 6. Save the results
                    mysql_connect.insert_content()
                except Exception:
                    log_.error(traceback.format_exc())
        log_message['executeTime'] = (time.time() - st_time) * 1000
        log_.info(log_message)
    except Exception as e:
        log_.error(e)
        log_.error(traceback.format_exc())


def praseGptRes(gpt_res2, gpt_res3, gpt_res4):
    """Parse the three GPT classification responses into a single flat dict."""
    result = {}
    if gpt_res2 is not None:
        try:
            res2 = json.loads(gpt_res2)
            result['key_words'] = res2['key_words']
            result['search_keys'] = res2['search_keys']
            result['extra_keys'] = res2['extra_keys']
        except Exception:
            pass
    if gpt_res3 is not None:
        try:
            res3 = json.loads(gpt_res3)
            result['tone'] = res3['tone']
            result['target_audience'] = res3['target_audience']
            result['target_age'] = res3['target_age']
        except Exception:
            pass
    if gpt_res4 is not None:
        try:
            res4 = json.loads(gpt_res4)
            result['category'] = res4['category']
            result['target_gender'] = res4['target_gender']
            result['address'] = res4['address']
            result['theme'] = res4['theme']
        except Exception:
            pass
    return result
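

# A minimal sketch of the JSON payloads praseGptRes expects, assuming the GPT
# prompts ask for exactly the fields parsed above (the shapes are illustrative,
# not taken from the prompt definitions in config):
#   gpt_res2 -> {"key_words": [...], "search_keys": [...], "extra_keys": [...]}
#   gpt_res3 -> {"tone": "...", "target_audience": "...", "target_age": "..."}
#   gpt_res4 -> {"category": "...", "target_gender": "...", "address": "...", "theme": "..."}
# A response that is missing a field or is not valid JSON is skipped silently.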


def process(video_id, video_info, download_folder):
    if video_info.get(video_id, None) is None:
        shutil.rmtree(os.path.join(download_folder, video_id))
    else:
        video_folder = os.path.join(download_folder, video_id)
        for filename in os.listdir(video_folder):
            video_type = filename.split('.')[-1]
            if video_type in ['mp4', 'm3u8']:
                video_file = os.path.join(video_folder, filename)
                # NOTE: legacy call path -- get_video_ai_tags now takes an
                # asr_file instead of a video_file, so this call needs to be
                # updated before this function is used again.
                get_video_ai_tags(
                    video_id=video_id, video_file=video_file,
                    video_info=video_info.get(video_id))
                # Delete the video once it has been processed
                shutil.rmtree(os.path.join(download_folder, video_id))
            else:
                shutil.rmtree(os.path.join(download_folder, video_id))


def ai_tags(project, table, dt):
    # Fetch the feature data
    feature_df = get_feature_data(
        project=project, table=table, dt=dt, features=features)
    video_id_list = feature_df['videoid'].to_list()
    video_info = {}
    for video_id in video_id_list:
        title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
        video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
        if title is None:
            continue
        title = title.strip()
        if len(title) > 0:
            video_info[video_id] = {'title': title, 'video_path': video_path}
            # print(video_id, title)
    print(len(video_info))

    # Pick up the videos that have already been downloaded
    download_folder = 'videos'
    retry = 0
    while retry < 3:
        video_folder_list = os.listdir(download_folder)
        if len(video_folder_list) < 2:
            retry += 1
            time.sleep(60)
            continue
        # pool = multiprocessing.Pool(processes=5)
        # for video_id in video_folder_list:
        #     if video_id not in video_id_list:
        #         continue
        #     pool.apply_async(
        #         func=process,
        #         args=(video_id, video_info, download_folder)
        #     )
        # pool.close()
        # pool.join()
        for video_id in video_folder_list:
            if video_id not in video_id_list:
                continue
            if video_info.get(video_id, None) is None:
                shutil.rmtree(os.path.join(download_folder, video_id))
            else:
                video_folder = os.path.join(download_folder, video_id)
                for filename in os.listdir(video_folder):
                    video_type = filename.split('.')[-1]
                    if video_type in ['mp4', 'm3u8']:
                        video_file = os.path.join(video_folder, filename)
                        # NOTE: legacy call path -- see `process` above; the
                        # current get_video_ai_tags signature expects asr_file.
                        get_video_ai_tags(
                            video_id=video_id, video_file=video_file,
                            video_info=video_info.get(video_id))
                        # Delete the video once it has been processed
                        shutil.rmtree(os.path.join(download_folder, video_id))
                    else:
                        shutil.rmtree(os.path.join(download_folder, video_id))


def ai_tags_new(project, table, dt):
    # Fetch the feature data
    feature_df = get_feature_data(
        project=project, table=table, dt=dt, features=features)
    video_id_list = feature_df['videoid'].to_list()
    video_info = {}
    for video_id in video_id_list:
        title = feature_df[feature_df['videoid'] == video_id]['title'].values[0]
        video_path = feature_df[feature_df['videoid'] == video_id]['video_path'].values[0]
        if title is None:
            continue
        title = title.strip()
        if len(title) > 0:
            video_info[video_id] = {'title': title, 'video_path': video_path}
            # print(video_id, title)
    print(len(video_info))

    # Pick up the videos whose ASR transcription has finished
    asr_folder = 'asr_res'
    retry = 0
    while retry < 30:
        asr_file_list = os.listdir(asr_folder)
        if len(asr_file_list) < 1:
            retry += 1
            time.sleep(60)
            continue
        retry = 0
        for asr_filename in asr_file_list:
            # Strip the '.txt' extension to recover the video id
            video_id = asr_filename[:-4]
            if video_id not in video_id_list:
                continue
            asr_file = os.path.join(asr_folder, asr_filename)
            if video_info.get(video_id, None) is None:
                os.remove(asr_file)
            else:
                get_video_ai_tags(
                    video_id=video_id, asr_file=asr_file,
                    video_info=video_info.get(video_id))
                os.remove(asr_file)
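

# ASR result files are assumed to be named "<videoid>.txt" (cf. the sample path
# asr_res/16598277.txt in the commented-out call at the bottom of this file);
# the loops here and in __main__ strip the 4-character ".txt" suffix to recover
# the video id.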


def timer_check():
    try:
        project = config_.DAILY_VIDEO['project']
        table = config_.DAILY_VIDEO['table']
        now_date = datetime.datetime.today()
        print(f"now_date: {datetime.datetime.strftime(now_date, '%Y%m%d')}")
        dt = datetime.datetime.strftime(
            now_date - datetime.timedelta(days=1), '%Y%m%d')
        # Check whether the data is ready
        data_count = data_check(project=project, table=table, dt=dt)
        if data_count > 0:
            print(f'videos count = {data_count}')
            asr_folder = 'asr_res'
            if not os.path.exists(asr_folder):
                # Re-check in 1 minute
                Timer(60, timer_check).start()
            else:
                # Data is ready, run the AI tagging
                ai_tags_new(project=project, table=table, dt=dt)
                print("videos ai tag finished!")
        else:
            # Data is not ready yet, re-check in 1 minute
            Timer(60, timer_check).start()
    except Exception as e:
        print(f"Video AI tagging failed, exception: {e}, traceback: {traceback.format_exc()}")


if __name__ == '__main__':
    # timer_check()
    size = 10000
    for i in range(0, 10000, size):
        print(f"query_videos start i = {i} ...")
        records = ODPSQueryUtil.query_videos(i, size)
        if records is None or len(records) == 0:
            continue
        print(f"Got {len(records)} records")
        video_info = {}
        # Walk the records and collect each video's info into a dict
        for record in records:
            # videoid comes back as a string; convert it to an integer key
            video_id = int(record['videoid'])
            title = record['title']
            video_path = record['video_path']
            # Use video_id as the key, the remaining fields as the value
            video_info[video_id] = {'title': title, 'video_path': video_path}
        # Print the result for inspection
        print(video_info)

        asr_folder = 'asr_res'
        retry = 0
        while retry < 30:
            asr_file_list = os.listdir(asr_folder)
            if len(asr_file_list) < 1:
                retry += 1
                time.sleep(60)
                continue
            retry = 0
            for asr_filename in asr_file_list:
                # Strip the '.txt' extension to recover the video id
                video_id = int(asr_filename[:-4])
                if video_id not in video_info:
                    continue
                asr_file = os.path.join(asr_folder, asr_filename)
                if video_info.get(video_id, None) is None:
                    os.remove(asr_file)
                else:
                    get_video_ai_tags(
                        video_id=video_id, asr_file=asr_file,
                        video_info=video_info.get(video_id))
                    os.remove(asr_file)

    # get_video_ai_tags(16598277, 'aigc-test/asr_res/16598277.txt',
    #                   {'title': '九九重阳节送祝福🚩', 'video_path': '视频路径'})
|