mrte.py

import torch
from torch import nn

from fish_speech.models.vits_decoder.modules.attentions import MultiHeadAttention


class MRTE(nn.Module):
    """Multi-Reference Timbre Encoder: fuses SSL content features with text
    features via cross-attention, plus a global (speaker) embedding offset."""

    def __init__(
        self,
        content_enc_channels=192,
        hidden_size=512,
        out_channels=192,
        kernel_size=5,
        n_heads=4,
        ge_layer=2,
    ):
        super().__init__()
        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
        # If no global embedding is supplied, fall back to a zero offset.
        if ge is None:
            ge = 0

        # Cross-attention mask: (b, 1, 1, t_text) * (b, 1, t_ssl, 1)
        # -> (b, 1, t_ssl, t_text), broadcast over attention heads.
        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

        # Project both streams into the shared hidden size.
        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)

        if test is not None:
            if test == 0:
                # Full path: cross-attention + content residual + global embedding.
                x = (
                    self.cross_attention(
                        ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ssl_enc
                    + ge
                )
            elif test == 1:
                # Ablation: drop the text branch entirely.
                x = ssl_enc + ge
            elif test == 2:
                # Ablation: zero out the content query, keep text attention only.
                x = (
                    self.cross_attention(
                        ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
                    )
                    + ge
                )
            else:
                raise ValueError("test should be 0, 1 or 2")
        else:
            x = (
                self.cross_attention(
                    ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
                )
                + ssl_enc
                + ge
            )
        x = self.c_post(x * ssl_mask)
        return x
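
A minimal shape-check sketch, not part of the original file: it assumes VITS-style masks of shape (batch, 1, time), a global embedding ge of shape (batch, hidden_size, 1) that broadcasts over time, and the MultiHeadAttention(query, key_value, attn_mask) call convention used above. The tensor sizes below are illustrative only.

if __name__ == "__main__":
    batch, t_ssl, t_text = 2, 100, 40
    mrte = MRTE()
    ssl_enc = torch.randn(batch, 192, t_ssl)       # SSL content features
    text = torch.randn(batch, 192, t_text)         # text encoder features
    ssl_mask = torch.ones(batch, 1, t_ssl)         # VITS-style padding masks
    text_mask = torch.ones(batch, 1, t_text)
    ge = torch.randn(batch, 512, 1)                # assumed global speaker embedding
    out = mrte(ssl_enc, ssl_mask, text, text_mask, ge)
    print(out.shape)  # expected: torch.Size([2, 192, 100])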