| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
from typing import Optional, Tuple, Union

import torch
import torchaudio.functional as F
from torch import Tensor, nn
from torchaudio.transforms import MelScale
class LinearSpectrogram(nn.Module):
    """Magnitude (or raw real/imaginary) STFT of a waveform.

    Before ``torch.stft`` is applied, the time axis is reflect-padded by
    ``win_length - hop_length`` samples in total, split as evenly as possible
    between the two ends.

    Args:
        n_fft: FFT size passed to ``torch.stft``.
        win_length: Length of the Hann analysis window.
        hop_length: Number of samples between successive frames.
        center: Forwarded to ``torch.stft``. When ``True``, ``torch.stft``
            pads again on top of the manual reflect padding in :meth:`forward`.
        mode: ``"pow2_sqrt"`` returns ``sqrt(re**2 + im**2 + 1e-6)``. Any other
            value returns the real and imaginary parts stacked in a trailing
            dimension of size 2.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode
        # Non-persistent buffer: it follows the module across devices but is
        # not stored in state_dict (it is rebuilt from win_length).
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        """Compute the spectrogram of ``y``.

        Args:
            y: Waveform of shape ``(time,)``, ``(batch, time)`` or
                ``(batch, 1, time)``.

        Returns:
            For ``mode="pow2_sqrt"``, a magnitude tensor of shape
            ``(batch, n_fft // 2 + 1, frames)``, with no batch dimension when
            ``y`` is 1-D. For any other mode, the same shape with a trailing
            dimension of size 2 holding the real and imaginary parts.
        """
        if y.ndim == 3:
            # Drop a singleton channel axis: (batch, 1, time) -> (batch, time).
            y = y.squeeze(1)

        # With center=False and n_fft == win_length, this padding yields
        # exactly time // hop_length frames.
        pad_left = (self.win_length - self.hop_length) // 2
        pad_right = (self.win_length - self.hop_length + 1) // 2

        # Reflect padding needs a 2-D/3-D input, so a singleton axis is inserted
        # just before time. Using -2 (not 1) makes 1-D input become (1, time)
        # instead of (time, 1), so unbatched waveforms work too; for 2-D input
        # the two are identical.
        y = torch.nn.functional.pad(
            y.unsqueeze(-2),
            (pad_left, pad_right),
            mode="reflect",
        ).squeeze(-2)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        if self.mode == "pow2_sqrt":
            # Small epsilon keeps sqrt away from 0, where its gradient is infinite.
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        return spec
class LogMelSpectrogram(nn.Module):
    """Log-mel spectrogram: linear STFT magnitude -> mel filterbank -> log.

    Args:
        sample_rate: Sample rate the mel filterbank is built for. Inputs at a
            different rate can be resampled to it in :meth:`forward`.
        n_fft: FFT size; the filterbank has ``n_fft // 2 + 1`` frequency rows.
        win_length: Analysis window length passed to :class:`LinearSpectrogram`.
        hop_length: Hop size passed to :class:`LinearSpectrogram`.
        n_mels: Number of mel bands.
        center: Forwarded to :class:`LinearSpectrogram`.
        f_min: Lower edge of the filterbank in Hz.
        f_max: Upper edge of the filterbank in Hz. ``None`` (or any other falsy
            value) falls back to ``sample_rate // 2``.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max or float(sample_rate // 2)

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)

        # Slaney-normalised, Slaney-scale filterbank. torchaudio returns it as
        # (n_freqs, n_mels), which is why apply_mel_scale multiplies on the right.
        fb = F.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        # Non-persistent: rebuilt from the constructor arguments and not
        # stored in state_dict.
        self.register_buffer("fb", fb, persistent=False)

    def compress(self, x: Tensor) -> Tensor:
        """Return the natural log of ``x``, with values floored at 1e-5 to avoid log(0)."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Invert :meth:`compress`. The inverse is exact only where the input was >= 1e-5."""
        return torch.exp(x)

    def apply_mel_scale(self, x: Tensor) -> Tensor:
        """Project a linear spectrogram ``(..., n_freqs, frames)`` onto mel bands.

        Returns a tensor of shape ``(..., n_mels, frames)``.
        """
        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)

    def forward(
        self,
        x: Tensor,
        return_linear: bool = False,
        sample_rate: Optional[int] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Compute the log-mel spectrogram of waveform ``x``.

        Args:
            x: Waveform accepted by :class:`LinearSpectrogram`.
            return_linear: If True, also return the log-compressed linear
                spectrogram.
            sample_rate: Sample rate of ``x``. If given and different from
                ``self.sample_rate``, ``x`` is resampled first.

        Returns:
            The log-mel spectrogram. When ``return_linear`` is True, returns
            the tuple ``(log_mel, log_linear)`` instead.
        """
        if sample_rate is not None and sample_rate != self.sample_rate:
            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)

        linear = self.spectrogram(x)
        mel = self.compress(self.apply_mel_scale(linear))

        if return_linear:
            return mel, self.compress(linear)
        return mel
|