فهرست منبع

handle bad files

Lengyue 2 سال پیش
والد
کامیت
6ac74ddea5
1فایلهای تغییر یافته به همراه14 افزوده شده و 2 حذف شده
  1. 14 2
      tools/vqgan/extract_vq.py

+ 14 - 2
tools/vqgan/extract_vq.py

@@ -67,18 +67,28 @@ def get_model(
 def process_batch(files: list[Path], model) -> float:
     wavs = []
     audio_lengths = []
+    new_files = []
     max_length = total_time = 0
 
     for file in files:
-        wav, sr = torchaudio.load(file)
+        try:
+            wav, sr = torchaudio.load(file)
+        except Exception as e:
+            logger.error(f"Error reading {file}: {e}")
+            continue
+
         if wav.shape[0] > 1:
             wav = wav.mean(dim=0, keepdim=True)
 
         wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
-        wavs.append(wav)
         total_time += len(wav) / model.sampling_rate
         max_length = max(max_length, len(wav))
+
+        wavs.append(wav)
         audio_lengths.append(len(wav))
+        new_files.append(file)
+
+    files = new_files
 
     # Pad to max length
     for i, wav in enumerate(wavs):
@@ -161,6 +171,8 @@ def main(
     else:
         files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
+    print(f"Found {len(files)} files")
+    files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
     Random(42).shuffle(files)
 
     total_files = len(files)