|
|
@@ -15,11 +15,15 @@ HUGE_GAP_THRESHOLD = 4000
|
|
|
@torch.no_grad()
|
|
|
@torch.autocast(device_type="cuda", dtype=torch.half)
|
|
|
def batch_encode(model, audios_list: list[bytes]):
|
|
|
+ # Get sample rate from model
|
|
|
+ if hasattr(model, "spec_transform"):
|
|
|
+ sample_rate = model.spec_transform.sample_rate
|
|
|
+ else:
|
|
|
+ sample_rate = model.sample_rate
|
|
|
+
|
|
|
audios: list[torch.Tensor] = [
|
|
|
(
|
|
|
- torch.from_numpy(
|
|
|
- librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
|
|
|
- )[None]
|
|
|
+ torch.from_numpy(librosa.load(io.BytesIO(audio), sr=sample_rate)[0])[None]
|
|
|
if isinstance(audio, bytes)
|
|
|
else audio
|
|
|
)
|
|
|
@@ -29,7 +33,7 @@ def batch_encode(model, audios_list: list[bytes]):
|
|
|
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
|
|
|
max_length = lengths.max().item()
|
|
|
|
|
|
- print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
|
|
|
+ print(f"Encode max length: {max_length / sample_rate:.2f}s")
|
|
|
|
|
|
padded = torch.stack(
|
|
|
[
|