condition.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import torch
  2. import torch.nn as nn
  3. class MultiCondLayer(nn.Module):
  4. def __init__(
  5. self,
  6. gin_channels: int,
  7. out_channels: int,
  8. n_cond: int,
  9. ):
  10. """MultiCondLayer of VITS model.
  11. Args:
  12. gin_channels (int): Number of conditioning tensor channels.
  13. out_channels (int): Number of output tensor channels.
  14. n_cond (int): Number of conditions.
  15. """
  16. super().__init__()
  17. self.n_cond = n_cond
  18. self.cond_layers = nn.ModuleList()
  19. for _ in range(n_cond):
  20. self.cond_layers.append(nn.Linear(gin_channels, out_channels))
  21. def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
  22. """
  23. Shapes:
  24. - cond: :math:`[B, C, N]`
  25. - x_mask: :math`[B, 1, T]`
  26. """
  27. cond_out = torch.zeros_like(cond)
  28. for i in range(self.n_cond):
  29. cond_in = self.cond_layers[i](cond.mT).mT
  30. cond_out = cond_out + cond_in
  31. return cond_out * x_mask