Преглед изворног кода

fix: add compatibility check for models without spec_transform attribute (#1056)

* fix: add compatibility check for models without spec_transform attribute

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Xudong Huang пре 9 месеци
родитељ
комит
feb987fc0e
1 измењен фајл са 8 додато и 4 уклоњено
  1. 8 4
      tools/server/model_utils.py

+ 8 - 4
tools/server/model_utils.py

@@ -15,11 +15,15 @@ HUGE_GAP_THRESHOLD = 4000
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.half)
 def batch_encode(model, audios_list: list[bytes]):
+    # Get sample rate from model
+    if hasattr(model, "spec_transform"):
+        sample_rate = model.spec_transform.sample_rate
+    else:
+        sample_rate = model.sample_rate
+
     audios: list[torch.Tensor] = [
         (
-            torch.from_numpy(
-                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
-            )[None]
+            torch.from_numpy(librosa.load(io.BytesIO(audio), sr=sample_rate)[0])[None]
             if isinstance(audio, bytes)
             else audio
         )
@@ -29,7 +33,7 @@ def batch_encode(model, audios_list: list[bytes]):
     lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
     max_length = lengths.max().item()
 
-    print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+    print(f"Encode max length: {max_length / sample_rate:.2f}s")
 
     padded = torch.stack(
         [