|
|
@@ -67,18 +67,28 @@ def get_model(
|
|
|
def process_batch(files: list[Path], model) -> float:
|
|
|
wavs = []
|
|
|
audio_lengths = []
|
|
|
+ new_files = []
|
|
|
max_length = total_time = 0
|
|
|
|
|
|
for file in files:
|
|
|
- wav, sr = torchaudio.load(file)
|
|
|
+ try:
|
|
|
+ wav, sr = torchaudio.load(file)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error reading {file}: {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
if wav.shape[0] > 1:
|
|
|
wav = wav.mean(dim=0, keepdim=True)
|
|
|
|
|
|
wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
|
|
|
- wavs.append(wav)
|
|
|
total_time += len(wav) / model.sampling_rate
|
|
|
max_length = max(max_length, len(wav))
|
|
|
+
|
|
|
+ wavs.append(wav)
|
|
|
audio_lengths.append(len(wav))
|
|
|
+ new_files.append(file)
|
|
|
+
|
|
|
+ files = new_files
|
|
|
|
|
|
# Pad to max length
|
|
|
for i, wav in enumerate(wavs):
|
|
|
@@ -161,6 +171,8 @@ def main(
|
|
|
else:
|
|
|
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
|
|
|
|
|
+ print(f"Found {len(files)} files")
|
|
|
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
|
|
|
Random(42).shuffle(files)
|
|
|
|
|
|
total_files = len(files)
|