spectrogram.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import torch
  2. from torch import Tensor, nn
  3. from torchaudio.transforms import MelScale
  4. class LinearSpectrogram(nn.Module):
  5. def __init__(
  6. self,
  7. n_fft=2048,
  8. win_length=2048,
  9. hop_length=512,
  10. center=False,
  11. mode="pow2_sqrt",
  12. ):
  13. super().__init__()
  14. self.n_fft = n_fft
  15. self.win_length = win_length
  16. self.hop_length = hop_length
  17. self.center = center
  18. self.mode = mode
  19. self.register_buffer("window", torch.hann_window(win_length))
  20. def forward(self, y: Tensor) -> Tensor:
  21. if y.ndim == 3:
  22. y = y.squeeze(1)
  23. y = torch.nn.functional.pad(
  24. y.unsqueeze(1),
  25. (
  26. (self.win_length - self.hop_length) // 2,
  27. (self.win_length - self.hop_length + 1) // 2,
  28. ),
  29. mode="reflect",
  30. ).squeeze(1)
  31. spec = torch.stft(
  32. y,
  33. self.n_fft,
  34. hop_length=self.hop_length,
  35. win_length=self.win_length,
  36. window=self.window,
  37. center=self.center,
  38. pad_mode="reflect",
  39. normalized=False,
  40. onesided=True,
  41. return_complex=True,
  42. )
  43. spec = torch.view_as_real(spec)
  44. if self.mode == "pow2_sqrt":
  45. spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
  46. return spec
  47. class LogMelSpectrogram(nn.Module):
  48. def __init__(
  49. self,
  50. sample_rate=44100,
  51. n_fft=2048,
  52. win_length=2048,
  53. hop_length=512,
  54. n_mels=128,
  55. center=False,
  56. f_min=0.0,
  57. f_max=None,
  58. ):
  59. super().__init__()
  60. self.sample_rate = sample_rate
  61. self.n_fft = n_fft
  62. self.win_length = win_length
  63. self.hop_length = hop_length
  64. self.center = center
  65. self.n_mels = n_mels
  66. self.f_min = f_min
  67. self.f_max = f_max or sample_rate // 2
  68. self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
  69. self.mel_scale = MelScale(
  70. self.n_mels,
  71. self.sample_rate,
  72. self.f_min,
  73. self.f_max,
  74. self.n_fft // 2 + 1,
  75. "slaney",
  76. "slaney",
  77. )
  78. def compress(self, x: Tensor) -> Tensor:
  79. return torch.log(torch.clamp(x, min=1e-5))
  80. def decompress(self, x: Tensor) -> Tensor:
  81. return torch.exp(x)
  82. def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
  83. linear = self.spectrogram(x)
  84. x = self.mel_scale(linear)
  85. x = self.compress(x)
  86. if return_linear:
  87. return x, self.compress(linear)
  88. return x