modules.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
# adapted from https://github.com/KinglittleQ/GST-Tacotron/blob/master/GST.py
# MIT License
#
# Copyright (c) 2018 MagicGirl Sakura
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
  23. import torch
  24. import torch.nn as nn
  25. import torch.nn.init as init
  26. import torch.nn.functional as F
  27. class ReferenceEncoder(nn.Module):
  28. '''
  29. inputs --- [N, Ty/r, n_mels*r] mels
  30. outputs --- [N, ref_enc_gru_size]
  31. '''
  32. def __init__(self, hp):
  33. super().__init__()
  34. K = len(hp.ref_enc_filters)
  35. filters = [1] + hp.ref_enc_filters
  36. convs = [nn.Conv2d(in_channels=filters[i],
  37. out_channels=filters[i + 1],
  38. kernel_size=(3, 3),
  39. stride=(2, 2),
  40. padding=(1, 1)) for i in range(K)]
  41. self.convs = nn.ModuleList(convs)
  42. self.bns = nn.ModuleList(
  43. [nn.BatchNorm2d(num_features=hp.ref_enc_filters[i])
  44. for i in range(K)])
  45. out_channels = self.calculate_channels(hp.n_mel_channels, 3, 2, 1, K)
  46. self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
  47. hidden_size=hp.ref_enc_gru_size,
  48. batch_first=True)
  49. self.n_mel_channels = hp.n_mel_channels
  50. self.ref_enc_gru_size = hp.ref_enc_gru_size
  51. def forward(self, inputs, input_lengths=None):
  52. out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels)
  53. for conv, bn in zip(self.convs, self.bns):
  54. out = conv(out)
  55. out = bn(out)
  56. out = F.relu(out)
  57. out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
  58. N, T = out.size(0), out.size(1)
  59. out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
  60. if input_lengths is not None:
  61. input_lengths = torch.ceil(input_lengths.float() / 2 ** len(self.convs))
  62. input_lengths = input_lengths.cpu().numpy().astype(int)
  63. out = nn.utils.rnn.pack_padded_sequence(
  64. out, input_lengths, batch_first=True, enforce_sorted=False)
  65. self.gru.flatten_parameters()
  66. _, out = self.gru(out)
  67. return out.squeeze(0)
  68. def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
  69. for _ in range(n_convs):
  70. L = (L - kernel_size + 2 * pad) // stride + 1
  71. return L
  72. class STL(nn.Module):
  73. '''
  74. inputs --- [N, token_embedding_size//2]
  75. '''
  76. def __init__(self, hp):
  77. super().__init__()
  78. self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.token_embedding_size // hp.num_heads))
  79. d_q = hp.ref_enc_gru_size
  80. d_k = hp.token_embedding_size // hp.num_heads
  81. self.attention = MultiHeadAttention(
  82. query_dim=d_q, key_dim=d_k, num_units=hp.token_embedding_size,
  83. num_heads=hp.num_heads)
  84. init.normal_(self.embed, mean=0, std=0.5)
  85. def forward(self, inputs):
  86. N = inputs.size(0)
  87. query = inputs.unsqueeze(1)
  88. keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, token_embedding_size // num_heads]
  89. style_embed = self.attention(query, keys)
  90. return style_embed
  91. class MultiHeadAttention(nn.Module):
  92. '''
  93. input:
  94. query --- [N, T_q, query_dim]
  95. key --- [N, T_k, key_dim]
  96. output:
  97. out --- [N, T_q, num_units]
  98. '''
  99. def __init__(self, query_dim, key_dim, num_units, num_heads):
  100. super().__init__()
  101. self.num_units = num_units
  102. self.num_heads = num_heads
  103. self.key_dim = key_dim
  104. self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
  105. self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
  106. self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
  107. def forward(self, query, key):
  108. querys = self.W_query(query) # [N, T_q, num_units]
  109. keys = self.W_key(key) # [N, T_k, num_units]
  110. values = self.W_value(key)
  111. split_size = self.num_units // self.num_heads
  112. querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
  113. keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
  114. values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
  115. # score = softmax(QK^T / (d_k ** 0.5))
  116. scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
  117. scores = scores / (self.key_dim ** 0.5)
  118. scores = F.softmax(scores, dim=3)
  119. # out = score * V
  120. out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
  121. out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
  122. return out
  123. class GST(nn.Module):
  124. def __init__(self, hp):
  125. super().__init__()
  126. self.encoder = ReferenceEncoder(hp)
  127. self.stl = STL(hp)
  128. def forward(self, inputs, input_lengths=None):
  129. enc_out = self.encoder(inputs, input_lengths=input_lengths)
  130. style_embed = self.stl(enc_out)
  131. return style_embed