# model_utils.py — batch encode/decode helpers for a VQGAN-style audio codec.
  1. import io
  2. import re
  3. import librosa
  4. import torch
  5. import torchaudio
  6. from cachetools import LRUCache, cached
# Maximum number of entries kept in the encode LRU cache below.
CACHE_MAXSIZE = 10000
# Number of items decoded per model.decode() call, to bound peak GPU memory.
MICRO_BATCH_SIZE = 8
# NOTE(review): unused in the visible code — presumably the sample rate (Hz)
# expected by a downstream ASR component; confirm at call sites.
ASR_SAMPLE_RATE = 16000
# NOTE(review): unused in the visible code — likely a gap threshold (samples
# or ms) used elsewhere; confirm before relying on its meaning.
HUGE_GAP_THRESHOLD = 4000
  11. @torch.no_grad()
  12. @torch.autocast(device_type="cuda", dtype=torch.half)
  13. def batch_encode(model, audios_list: list[bytes]):
  14. # Get sample rate from model
  15. if hasattr(model, "spec_transform"):
  16. sample_rate = model.spec_transform.sample_rate
  17. else:
  18. sample_rate = model.sample_rate
  19. audios: list[torch.Tensor] = [
  20. (
  21. torch.from_numpy(librosa.load(io.BytesIO(audio), sr=sample_rate)[0])[None]
  22. if isinstance(audio, bytes)
  23. else audio
  24. )
  25. for audio in audios_list
  26. ]
  27. lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
  28. max_length = lengths.max().item()
  29. print(f"Encode max length: {max_length / sample_rate:.2f}s")
  30. padded = torch.stack(
  31. [
  32. torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
  33. for audio in audios
  34. ]
  35. ).to(model.device)
  36. features, feature_lengths = model.encode(padded, audio_lengths=lengths)
  37. features, feature_lengths = features.cpu(), feature_lengths.cpu()
  38. return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
  39. @cached(
  40. cache=LRUCache(maxsize=CACHE_MAXSIZE),
  41. key=lambda model, audios: (model.device, tuple(audios)),
  42. )
  43. def cached_vqgan_batch_encode(model, audios: list[bytes]):
  44. return batch_encode(model, audios)
  45. @torch.no_grad()
  46. @torch.autocast(device_type="cuda", dtype=torch.half)
  47. def batch_vqgan_decode(model, features):
  48. lengths = torch.tensor(
  49. [feature.shape[-1] for feature in features], device=model.device
  50. )
  51. max_length = lengths.max().item()
  52. padded = torch.stack(
  53. [
  54. torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
  55. for feature in features
  56. ]
  57. ).to(model.device)
  58. # If bs too large, we do micro batch decode
  59. audios, audio_lengths = [], []
  60. for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
  61. audio, audio_length = model.decode(
  62. padded[i : i + MICRO_BATCH_SIZE],
  63. feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
  64. )
  65. audios.append(audio)
  66. audio_lengths.append(audio_length)
  67. audios = torch.cat(audios, dim=0)
  68. audio_lengths = torch.cat(audio_lengths, dim=0)
  69. audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
  70. return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]