# spectrogram.py

from typing import Optional, Tuple, Union

import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale
  5. class LinearSpectrogram(nn.Module):
  6. def __init__(
  7. self,
  8. n_fft=2048,
  9. win_length=2048,
  10. hop_length=512,
  11. center=False,
  12. mode="pow2_sqrt",
  13. ):
  14. super().__init__()
  15. self.n_fft = n_fft
  16. self.win_length = win_length
  17. self.hop_length = hop_length
  18. self.center = center
  19. self.mode = mode
  20. self.return_complex = True
  21. self.register_buffer("window", torch.hann_window(win_length), persistent=False)
  22. def forward(self, y: Tensor) -> Tensor:
  23. if y.ndim == 3:
  24. y = y.squeeze(1)
  25. y = torch.nn.functional.pad(
  26. y.unsqueeze(1),
  27. (
  28. (self.win_length - self.hop_length) // 2,
  29. (self.win_length - self.hop_length + 1) // 2,
  30. ),
  31. mode="reflect",
  32. ).squeeze(1)
  33. spec = torch.stft(
  34. y,
  35. self.n_fft,
  36. hop_length=self.hop_length,
  37. win_length=self.win_length,
  38. window=self.window,
  39. center=self.center,
  40. pad_mode="reflect",
  41. normalized=False,
  42. onesided=True,
  43. return_complex=self.return_complex,
  44. )
  45. if self.return_complex:
  46. spec = torch.view_as_real(spec)
  47. if self.mode == "pow2_sqrt":
  48. spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
  49. return spec
  50. class LogMelSpectrogram(nn.Module):
  51. def __init__(
  52. self,
  53. sample_rate=44100,
  54. n_fft=2048,
  55. win_length=2048,
  56. hop_length=512,
  57. n_mels=128,
  58. center=False,
  59. f_min=0.0,
  60. f_max=None,
  61. ):
  62. super().__init__()
  63. self.sample_rate = sample_rate
  64. self.n_fft = n_fft
  65. self.win_length = win_length
  66. self.hop_length = hop_length
  67. self.center = center
  68. self.n_mels = n_mels
  69. self.f_min = f_min
  70. self.f_max = f_max or float(sample_rate // 2)
  71. self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
  72. fb = F.melscale_fbanks(
  73. n_freqs=self.n_fft // 2 + 1,
  74. f_min=self.f_min,
  75. f_max=self.f_max,
  76. n_mels=self.n_mels,
  77. sample_rate=self.sample_rate,
  78. norm="slaney",
  79. mel_scale="slaney",
  80. )
  81. self.register_buffer(
  82. "fb",
  83. fb,
  84. persistent=False,
  85. )
  86. def compress(self, x: Tensor) -> Tensor:
  87. return torch.log(torch.clamp(x, min=1e-5))
  88. def decompress(self, x: Tensor) -> Tensor:
  89. return torch.exp(x)
  90. def apply_mel_scale(self, x: Tensor) -> Tensor:
  91. return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
  92. def forward(
  93. self, x: Tensor, return_linear: bool = False, sample_rate: int = None
  94. ) -> Tensor:
  95. if sample_rate is not None and sample_rate != self.sample_rate:
  96. x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
  97. linear = self.spectrogram(x)
  98. x = self.apply_mel_scale(linear)
  99. x = self.compress(x)
  100. if return_linear:
  101. return x, self.compress(linear)
  102. return x