# model_utils.py — batched VQGAN audio encode/decode and batched ASR helpers.
  1. import io
  2. import re
  3. import librosa
  4. import torch
  5. import torchaudio
  6. from cachetools import LRUCache, cached
# Maximum number of entries retained by the LRU cache wrapping batch_encode.
CACHE_MAXSIZE = 10000
# Micro-batch size used by vqgan_decode to bound per-step decode memory.
MICRO_BATCH_SIZE = 8
# Sample rate (Hz) the ASR model expects; inputs are resampled to this in batch_asr.
ASR_SAMPLE_RATE = 16000
# Silence-gap threshold in milliseconds (4 s) used by batch_asr's huge_gap check
# (durations there are computed as len(audio) / sr * 1000).
HUGE_GAP_THRESHOLD = 4000
  11. @torch.no_grad()
  12. @torch.autocast(device_type="cuda", dtype=torch.half)
  13. def batch_encode(model, audios_list: list[bytes]):
  14. audios: list[torch.Tensor] = [
  15. (
  16. torch.from_numpy(
  17. librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
  18. )[None]
  19. if isinstance(audio, bytes)
  20. else audio
  21. )
  22. for audio in audios_list
  23. ]
  24. lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
  25. max_length = lengths.max().item()
  26. print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
  27. padded = torch.stack(
  28. [
  29. torch.nn.functional.pad(audio, (0, int(max_length - audio.shape[-1])))
  30. for audio in audios
  31. ]
  32. ).to(model.device)
  33. features, feature_lengths = model.encode(padded, audio_lengths=lengths)
  34. features, feature_lengths = features.cpu(), feature_lengths.cpu()
  35. return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
@cached(
    cache=LRUCache(maxsize=CACHE_MAXSIZE),
    # Key on the device plus the (hashable) tuple of audio items, so the same
    # batch re-encoded on the same device hits the cache.
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    """Memoized wrapper around batch_encode (LRU, up to CACHE_MAXSIZE batches).

    Note the cache key is the whole batch: the same clips grouped into a
    different batch will not hit the cache. Items in ``audios`` must be
    hashable (bytes are; raw tensors hash by identity).
    """
    return batch_encode(model, audios)
  42. @torch.no_grad()
  43. @torch.autocast(device_type="cuda", dtype=torch.half)
  44. def vqgan_decode(model, features):
  45. lengths = torch.tensor(
  46. [feature.shape[-1] for feature in features], device=model.device
  47. )
  48. max_length = lengths.max().item()
  49. padded = torch.stack(
  50. [
  51. torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
  52. for feature in features
  53. ]
  54. ).to(model.device)
  55. # If bs too large, we do micro batch decode
  56. audios, audio_lengths = [], []
  57. for i in range(0, padded.shape[0], MICRO_BATCH_SIZE):
  58. audio, audio_length = model.decode(
  59. padded[i : i + MICRO_BATCH_SIZE],
  60. feature_lengths=lengths[i : i + MICRO_BATCH_SIZE],
  61. )
  62. audios.append(audio)
  63. audio_lengths.append(audio_length)
  64. audios = torch.cat(audios, dim=0)
  65. audio_lengths = torch.cat(audio_lengths, dim=0)
  66. audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
  67. return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
  68. @torch.no_grad()
  69. def batch_asr(model, lock, audios, sr, language="auto"):
  70. resampled_audios = []
  71. for audio in audios:
  72. audio = torchaudio.functional.resample(audio, sr, ASR_SAMPLE_RATE)
  73. assert audio.ndim == 1
  74. resampled_audios.append(audio)
  75. with lock:
  76. res = model.generate(
  77. input=resampled_audios,
  78. batch_size=len(resampled_audios),
  79. language=language,
  80. use_itn=True,
  81. )
  82. results = []
  83. for r, audio in zip(res, audios):
  84. text = r["text"]
  85. text = re.sub(r"<\|.*?\|>", "", text)
  86. duration = len(audio) / sr * 1000
  87. huge_gap = False
  88. if "timestamp" in r and len(r["timestamp"]) > 2:
  89. for timestamp_a, timestamp_b in zip(
  90. r["timestamp"][:-1], r["timestamp"][1:]
  91. ):
  92. # If there is a gap of more than 4 seconds, we consider it as a huge gap
  93. if timestamp_b[0] - timestamp_a[1] > HUGE_GAP_THRESHOLD:
  94. huge_gap = True
  95. break
  96. # Doesn't make sense to have a huge gap at the end
  97. if duration - r["timestamp"][-1][1] > HUGE_GAP_THRESHOLD:
  98. huge_gap = True
  99. results.append(
  100. {
  101. "text": text,
  102. "duration": duration,
  103. "huge_gap": huge_gap,
  104. }
  105. )
  106. return results