losses.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
  5. loss = 0
  6. for dr, dg in zip(fmap_r, fmap_g):
  7. dr = dr.float().detach()
  8. dg = dg.float()
  9. loss += torch.mean(torch.abs(dr - dg))
  10. return loss * 2
  11. def discriminator_loss(
  12. disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
  13. ):
  14. loss = 0
  15. r_losses = []
  16. g_losses = []
  17. for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
  18. dr = dr.float()
  19. dg = dg.float()
  20. r_loss = torch.mean((1 - dr) ** 2)
  21. g_loss = torch.mean(dg**2)
  22. loss += r_loss + g_loss
  23. r_losses.append(r_loss.item())
  24. g_losses.append(g_loss.item())
  25. return loss, r_losses, g_losses
  26. def generator_loss(disc_outputs: list[torch.Tensor]):
  27. loss = 0
  28. gen_losses = []
  29. for dg in disc_outputs:
  30. dg = dg.float()
  31. l = torch.mean((1 - dg) ** 2)
  32. gen_losses.append(l)
  33. loss += l
  34. return loss, gen_losses
  35. def kl_loss(
  36. z_p: torch.Tensor,
  37. logs_q: torch.Tensor,
  38. m_p: torch.Tensor,
  39. logs_p: torch.Tensor,
  40. z_mask: torch.Tensor,
  41. ):
  42. """
  43. z_p, logs_q: [b, h, t_t]
  44. m_p, logs_p: [b, h, t_t]
  45. """
  46. z_p = z_p.float()
  47. logs_q = logs_q.float()
  48. m_p = m_p.float()
  49. logs_p = logs_p.float()
  50. z_mask = z_mask.float()
  51. kl = logs_p - logs_q - 0.5
  52. kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
  53. kl = torch.sum(kl * z_mask)
  54. l = kl / torch.sum(z_mask)
  55. return l
  56. def stft(x, fft_size, hop_size, win_length, window):
  57. """Perform STFT and convert to magnitude spectrogram.
  58. Args:
  59. x (Tensor): Input signal tensor (B, T).
  60. fft_size (int): FFT size.
  61. hop_size (int): Hop size.
  62. win_length (int): Window length.
  63. window (str): Window function type.
  64. Returns:
  65. Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
  66. """
  67. spec = torch.stft(
  68. x,
  69. fft_size,
  70. hop_size,
  71. win_length,
  72. window,
  73. return_complex=True,
  74. pad_mode="reflect",
  75. )
  76. spec = torch.view_as_real(spec)
  77. # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
  78. return torch.sqrt(torch.clamp(spec.pow(2).sum(-1), min=1e-6)).transpose(2, 1)
  79. class SpectralConvergengeLoss(nn.Module):
  80. """Spectral convergence loss module."""
  81. def __init__(self):
  82. """Initialize spectral convergence loss module."""
  83. super(SpectralConvergengeLoss, self).__init__()
  84. def forward(self, x_mag, y_mag):
  85. """Calculate forward propagation.
  86. Args:
  87. x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
  88. y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
  89. Returns:
  90. Tensor: Spectral convergence loss value.
  91. """ # noqa: E501
  92. return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
  93. class LogSTFTMagnitudeLoss(nn.Module):
  94. """Log STFT magnitude loss module."""
  95. def __init__(self):
  96. """Initialize los STFT magnitude loss module."""
  97. super(LogSTFTMagnitudeLoss, self).__init__()
  98. def forward(self, x_mag, y_mag):
  99. """Calculate forward propagation.
  100. Args:
  101. x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
  102. y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
  103. Returns:
  104. Tensor: Log STFT magnitude loss value.
  105. """ # noqa: E501
  106. return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
  107. class STFTLoss(nn.Module):
  108. """STFT loss module."""
  109. def __init__(
  110. self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
  111. ):
  112. """Initialize STFT loss module."""
  113. super(STFTLoss, self).__init__()
  114. self.fft_size = fft_size
  115. self.shift_size = shift_size
  116. self.win_length = win_length
  117. self.register_buffer("window", window(win_length))
  118. self.spectral_convergenge_loss = SpectralConvergengeLoss()
  119. self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
  120. def forward(self, x, y):
  121. """Calculate forward propagation.
  122. Args:
  123. x (Tensor): Predicted signal (B, T).
  124. y (Tensor): Groundtruth signal (B, T).
  125. Returns:
  126. Tensor: Spectral convergence loss value.
  127. Tensor: Log STFT magnitude loss value.
  128. """
  129. x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
  130. y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
  131. sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
  132. mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
  133. return sc_loss, mag_loss
  134. class MultiResolutionSTFTLoss(nn.Module):
  135. """Multi resolution STFT loss module."""
  136. def __init__(self, resolutions, window=torch.hann_window):
  137. super(MultiResolutionSTFTLoss, self).__init__()
  138. self.stft_losses = nn.ModuleList()
  139. for fs, ss, wl in resolutions:
  140. self.stft_losses += [STFTLoss(fs, ss, wl, window)]
  141. def forward(self, x, y):
  142. """Calculate forward propagation.
  143. Args:
  144. x (Tensor): Predicted signal (B, T).
  145. y (Tensor): Groundtruth signal (B, T).
  146. Returns:
  147. Tensor: Multi resolution spectral convergence loss value.
  148. Tensor: Multi resolution log STFT magnitude loss value.
  149. """
  150. sc_loss = 0.0
  151. mag_loss = 0.0
  152. for f in self.stft_losses:
  153. sc_l, mag_l = f(x, y)
  154. sc_loss += sc_l
  155. mag_loss += mag_l
  156. sc_loss /= len(self.stft_losses)
  157. mag_loss /= len(self.stft_losses)
  158. return sc_loss, mag_loss