spectrogram.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
from typing import Optional, Tuple, Union

import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale
  5. class LinearSpectrogram(nn.Module):
  6. def __init__(
  7. self,
  8. n_fft=2048,
  9. win_length=2048,
  10. hop_length=512,
  11. center=False,
  12. mode="pow2_sqrt",
  13. ):
  14. super().__init__()
  15. self.n_fft = n_fft
  16. self.win_length = win_length
  17. self.hop_length = hop_length
  18. self.center = center
  19. self.mode = mode
  20. self.register_buffer("window", torch.hann_window(win_length), persistent=False)
  21. def forward(self, y: Tensor) -> Tensor:
  22. if y.ndim == 3:
  23. y = y.squeeze(1)
  24. y = torch.nn.functional.pad(
  25. y.unsqueeze(1),
  26. (
  27. (self.win_length - self.hop_length) // 2,
  28. (self.win_length - self.hop_length + 1) // 2,
  29. ),
  30. mode="reflect",
  31. ).squeeze(1)
  32. spec = torch.stft(
  33. y,
  34. self.n_fft,
  35. hop_length=self.hop_length,
  36. win_length=self.win_length,
  37. window=self.window,
  38. center=self.center,
  39. pad_mode="reflect",
  40. normalized=False,
  41. onesided=True,
  42. return_complex=True,
  43. )
  44. spec = torch.view_as_real(spec)
  45. if self.mode == "pow2_sqrt":
  46. spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
  47. return spec
  48. class LogMelSpectrogram(nn.Module):
  49. def __init__(
  50. self,
  51. sample_rate=44100,
  52. n_fft=2048,
  53. win_length=2048,
  54. hop_length=512,
  55. n_mels=128,
  56. center=False,
  57. f_min=0.0,
  58. f_max=None,
  59. ):
  60. super().__init__()
  61. self.sample_rate = sample_rate
  62. self.n_fft = n_fft
  63. self.win_length = win_length
  64. self.hop_length = hop_length
  65. self.center = center
  66. self.n_mels = n_mels
  67. self.f_min = f_min
  68. self.f_max = f_max or float(sample_rate // 2)
  69. self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
  70. fb = F.melscale_fbanks(
  71. n_freqs=self.n_fft // 2 + 1,
  72. f_min=self.f_min,
  73. f_max=self.f_max,
  74. n_mels=self.n_mels,
  75. sample_rate=self.sample_rate,
  76. norm="slaney",
  77. mel_scale="slaney",
  78. )
  79. self.register_buffer(
  80. "fb",
  81. fb,
  82. persistent=False,
  83. )
  84. def compress(self, x: Tensor) -> Tensor:
  85. return torch.log(torch.clamp(x, min=1e-5))
  86. def decompress(self, x: Tensor) -> Tensor:
  87. return torch.exp(x)
  88. def apply_mel_scale(self, x: Tensor) -> Tensor:
  89. return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
  90. def forward(
  91. self, x: Tensor, return_linear: bool = False, sample_rate: int = None
  92. ) -> Tensor:
  93. if sample_rate is not None and sample_rate != self.sample_rate:
  94. x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
  95. linear = self.spectrogram(x)
  96. x = self.apply_mel_scale(linear)
  97. x = self.compress(x)
  98. if return_linear:
  99. return x, self.compress(linear)
  100. return x