mrte.py

import torch
from torch import nn

from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention


class MRTE(nn.Module):
    """Multi-Reference Timbre Encoder: fuses SSL content features with text
    features via cross-attention, then adds a global (speaker) embedding."""

    def __init__(
        self,
        content_enc_channels=192,
        hidden_size=512,
        out_channels=192,
        n_heads=4,
    ):
        super().__init__()
        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
        # Fall back to a zero global embedding when none is provided.
        if ge is None:
            ge = 0

        # Cross-attention mask over (B, 1, T_ssl, T_text).
        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

        # Project both streams to the shared hidden size.
        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)

        if test is not None:
            if test == 0:
                # Full path: cross-attend to text, residual SSL features,
                # plus the global embedding.
                x = (
                    self.cross_attention(
                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ssl_enc
                    + ge
                )
            elif test == 1:
                # Ablation: skip cross-attention entirely.
                x = ssl_enc + ge
            elif test == 2:
                # Ablation: zero the SSL query and drop the residual.
                x = (
                    self.cross_attention(
                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ge
                )
            else:
                raise ValueError("test must be 0, 1, or 2")
        else:
            x = (
                self.cross_attention(
                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                )
                + ssl_enc
                + ge
            )

        x = self.c_post(x * ssl_mask)
        return x
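

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# the expected tensor shapes. Assumptions: fish_speech is importable,
# MultiHeadAttention follows the usual VITS signature
# (channels, out_channels, n_heads) with forward(x, c, attn_mask), masks are
# (B, 1, T) float tensors, and ge is a (B, hidden_size, 1) global speaker
# embedding. Batch and sequence lengths below are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, t_ssl, t_text = 2, 100, 30
    mrte = MRTE()
    ssl_enc = torch.randn(batch, 192, t_ssl)
    ssl_mask = torch.ones(batch, 1, t_ssl)
    text = torch.randn(batch, 192, t_text)
    text_mask = torch.ones(batch, 1, t_text)
    ge = torch.randn(batch, 512, 1)  # hypothetical global embedding
    out = mrte(ssl_enc, ssl_mask, text, text_mask, ge)
    print(out.shape)  # expected: torch.Size([2, 192, 100])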