Explorar el Código

add multiprocessing

liqian hace 1 año
padre
commit
37a26dd602
Se han modificado 1 ficheros con 51 adiciones y 15 borrados
  1. 51 15
      asr_task.py

+ 51 - 15
asr_task.py

@@ -17,6 +17,24 @@ log_ = Log()
 features = ['videoid', 'title', 'video_path']
 
 
+def get_asr(video_id, download_folder, asr_folder):
+    video_folder = os.path.join(download_folder, video_id)
+    for filename in os.listdir(video_folder):
+        video_type = filename.split('.')[-1]
+        if video_type in ['mp4', 'm3u8']:
+            video_file = os.path.join(video_folder, filename)
+            # 1. asr识别
+            asr_res_initial = get_whisper_asr(video=video_file)
+            print(video_id, asr_res_initial)
+            # 2. 识别结果写入文件
+            asr_path = os.path.join(asr_folder, f"{video_id}.txt")
+            with open(asr_path, 'w', encoding='utf-8') as wf:
+                wf.write(asr_res_initial)
+            # 将处理过的视频进行删除
+            shutil.rmtree(os.path.join(download_folder, video_id))
+            break
+
+
 def asr_process(project, table, dt):
     # 获取特征数据
     feature_df = get_feature_data(project=project, table=table, dt=dt, features=features)
@@ -41,6 +59,33 @@ def asr_process(project, table, dt):
             retry += 1
             time.sleep(60)
             continue
+        retry = 0
+        # for video_id in video_folder_list:
+        #     if video_id not in video_id_list:
+        #         continue
+        #     if video_info.get(video_id, None) is None:
+        #         try:
+        #             shutil.rmtree(os.path.join(download_folder, video_id))
+        #         except:
+        #             continue
+        #     else:
+        #         video_folder = os.path.join(download_folder, video_id)
+        #         for filename in os.listdir(video_folder):
+        #             video_type = filename.split('.')[-1]
+        #             if video_type in ['mp4', 'm3u8']:
+        #                 video_file = os.path.join(video_folder, filename)
+        #                 # 1. asr识别
+        #                 asr_res_initial = get_whisper_asr(video=video_file)
+        #                 print(video_id, asr_res_initial)
+        #                 # 2. 识别结果写入文件
+        #                 asr_path = os.path.join(asr_folder, f"{video_id}.txt")
+        #                 with open(asr_path, 'w', encoding='utf-8') as wf:
+        #                     wf.write(asr_res_initial)
+        #                 # 将处理过的视频进行删除
+        #                 shutil.rmtree(os.path.join(download_folder, video_id))
+        #                 break
+
+        pool = multiprocessing.Pool(processes=2)
         for video_id in video_folder_list:
             if video_id not in video_id_list:
                 continue
@@ -50,21 +95,12 @@ def asr_process(project, table, dt):
                 except:
                     continue
             else:
-                video_folder = os.path.join(download_folder, video_id)
-                for filename in os.listdir(video_folder):
-                    video_type = filename.split('.')[-1]
-                    if video_type in ['mp4', 'm3u8']:
-                        video_file = os.path.join(video_folder, filename)
-                        # 1. asr识别
-                        asr_res_initial = get_whisper_asr(video=video_file)
-                        print(video_id, asr_res_initial)
-                        # 2. 识别结果写入文件
-                        asr_path = os.path.join(asr_folder, f"{video_id}.txt")
-                        with open(asr_path, 'w', encoding='utf-8') as wf:
-                            wf.write(asr_res_initial)
-                        # 将处理过的视频进行删除
-                        shutil.rmtree(os.path.join(download_folder, video_id))
-                        break
+                pool.apply_async(
+                    func=get_asr,
+                    args=(video_id, download_folder, asr_folder)
+                )
+        pool.close()
+        pool.join()
 
 
 def timer_check():