losses.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
  5. loss = 0
  6. for dr, dg in zip(fmap_r, fmap_g):
  7. dr = dr.float().detach()
  8. dg = dg.float()
  9. loss += torch.mean(torch.abs(dr - dg))
  10. return loss * 2
  11. def discriminator_loss(
  12. disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
  13. ):
  14. loss = 0
  15. r_losses = []
  16. g_losses = []
  17. for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
  18. dr = dr.float()
  19. dg = dg.float()
  20. r_loss = torch.mean((1 - dr) ** 2)
  21. g_loss = torch.mean(dg**2)
  22. loss += r_loss + g_loss
  23. r_losses.append(r_loss.item())
  24. g_losses.append(g_loss.item())
  25. return loss, r_losses, g_losses
  26. def generator_loss(disc_outputs: list[torch.Tensor]):
  27. loss = 0
  28. gen_losses = []
  29. for dg in disc_outputs:
  30. dg = dg.float()
  31. l = torch.mean((1 - dg) ** 2)
  32. gen_losses.append(l)
  33. loss += l
  34. return loss, gen_losses
  35. def kl_loss(
  36. z_p: torch.Tensor,
  37. logs_q: torch.Tensor,
  38. m_p: torch.Tensor,
  39. logs_p: torch.Tensor,
  40. z_mask: torch.Tensor,
  41. ):
  42. """
  43. z_p, logs_q: [b, h, t_t]
  44. m_p, logs_p: [b, h, t_t]
  45. """
  46. z_p = z_p.float()
  47. logs_q = logs_q.float()
  48. m_p = m_p.float()
  49. logs_p = logs_p.float()
  50. z_mask = z_mask.float()
  51. kl = logs_p - logs_q - 0.5
  52. kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
  53. kl = torch.sum(kl * z_mask)
  54. l = kl / torch.sum(z_mask)
  55. return l