Selaa lähdekoodia

update with multiprocessing

liqian 1 vuosi sitten
vanhempi
commit
6ff42ab49a
1 muutettua tiedostoa jossa 38 lisäystä ja 14 poistoa
  1. 38 14
      ai_tag_task.py

+ 38 - 14
ai_tag_task.py

@@ -5,6 +5,7 @@ import datetime
 import time
 import traceback
 import requests
+import multiprocessing
 from threading import Timer
 from utils import data_check, get_feature_data, asr_validity_discrimination
 from whisper_asr import get_whisper_asr
@@ -52,7 +53,7 @@ def get_video_ai_tags(video_id, video_file, video_info):
                     summary = gpt_res1_json['summary']
                     keywords = gpt_res1_json['keywords']
                     log_message['summary'] = summary
-                    log_message['keywords'] = keywords
+                    log_message['keywords'] = str(keywords)
                     prompt2_param = f"标题:{title}\n概况:{summary}\n关键词:{keywords}"
                     prompt2 = f"{config_.GPT_PROMPT['tags']['prompt7']}{prompt2_param}"
                     log_message['gptPromptTag'] = prompt2
@@ -87,6 +88,22 @@ def get_video_ai_tags(video_id, video_file, video_info):
         log_.error(traceback.format_exc())
 
 
+def process(video_id, video_info, download_folder):
+    """Run get_video_ai_tags on each mp4/m3u8 file of a video, then drop its folder."""
+    video_folder = os.path.join(download_folder, video_id)
+    info = video_info.get(video_id, None)
+    # Without metadata nothing can be tagged; the folder is only removed.
+    filenames = [] if info is None else os.listdir(video_folder)
+    try:
+        for filename in filenames:
+            if filename.split('.')[-1] in ['mp4', 'm3u8']:
+                video_file = os.path.join(video_folder, filename)
+                get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=info)
+    finally:
+        if info is None or filenames:  # an empty folder is left alone, as before
+            shutil.rmtree(video_folder)  # once, after every file was handled
+
+
 def ai_tags(project, table, dt):
     # 获取特征数据
     feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
@@ -111,22 +128,29 @@ def ai_tags(project, table, dt):
             retry += 1
             time.sleep(60)
             continue
+        pool = multiprocessing.Pool(processes=5)
         for video_id in video_folder_list:
             if video_id not in video_id_list:
                 continue
-            if video_info.get(video_id, None) is None:
-                shutil.rmtree(os.path.join(download_folder, video_id))
-            else:
-                video_folder = os.path.join(download_folder, video_id)
-                for filename in os.listdir(video_folder):
-                    video_type = filename.split('.')[-1]
-                    if video_type in ['mp4', 'm3u8']:
-                        video_file = os.path.join(video_folder, filename)
-                        get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
-                        # 将处理过的视频进行删除
-                        shutil.rmtree(os.path.join(download_folder, video_id))
-                    else:
-                        shutil.rmtree(os.path.join(download_folder, video_id))
+            pool.apply_async(
+                func=process,
+                args=(video_id, video_info, download_folder)
+            )
+        pool.close()
+        pool.join()
+            # if video_info.get(video_id, None) is None:
+            #     shutil.rmtree(os.path.join(download_folder, video_id))
+            # else:
+            #     video_folder = os.path.join(download_folder, video_id)
+            #     for filename in os.listdir(video_folder):
+            #         video_type = filename.split('.')[-1]
+            #         if video_type in ['mp4', 'm3u8']:
+            #             video_file = os.path.join(video_folder, filename)
+            #             get_video_ai_tags(video_id=video_id, video_file=video_file, video_info=video_info.get(video_id))
+            #             # 将处理过的视频进行删除
+            #             shutil.rmtree(os.path.join(download_folder, video_id))
+            #         else:
+            #             shutil.rmtree(os.path.join(download_folder, video_id))
 
 
 def timer_check():