# glow.py (13 KB)
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
  27. import copy
  28. import torch
  29. from torch.autograd import Variable
  30. import torch.nn.functional as F
  31. @torch.jit.script
  32. def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
  33. n_channels_int = n_channels[0]
  34. in_act = input_a+input_b
  35. t_act = torch.tanh(in_act[:, :n_channels_int, :])
  36. s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
  37. acts = t_act * s_act
  38. return acts
  39. class WaveGlowLoss(torch.nn.Module):
  40. def __init__(self, sigma=1.0):
  41. super(WaveGlowLoss, self).__init__()
  42. self.sigma = sigma
  43. def forward(self, model_output):
  44. z, log_s_list, log_det_W_list = model_output
  45. for i, log_s in enumerate(log_s_list):
  46. if i == 0:
  47. log_s_total = torch.sum(log_s)
  48. log_det_W_total = log_det_W_list[i]
  49. else:
  50. log_s_total = log_s_total + torch.sum(log_s)
  51. log_det_W_total += log_det_W_list[i]
  52. loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
  53. return loss/(z.size(0)*z.size(1)*z.size(2))
  54. class Invertible1x1Conv(torch.nn.Module):
  55. """
  56. The layer outputs both the convolution, and the log determinant
  57. of its weight matrix. If reverse=True it does convolution with
  58. inverse
  59. """
  60. def __init__(self, c):
  61. super(Invertible1x1Conv, self).__init__()
  62. self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
  63. bias=False)
  64. # Sample a random orthonormal matrix to initialize weights
  65. W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
  66. # Ensure determinant is 1.0 not -1.0
  67. if torch.det(W) < 0:
  68. W[:,0] = -1*W[:,0]
  69. W = W.view(c, c, 1)
  70. self.conv.weight.data = W
  71. def forward(self, z, reverse=False):
  72. # shape
  73. batch_size, group_size, n_of_groups = z.size()
  74. W = self.conv.weight.squeeze()
  75. if reverse:
  76. if not hasattr(self, 'W_inverse'):
  77. # Reverse computation
  78. W_inverse = W.float().inverse()
  79. W_inverse = Variable(W_inverse[..., None])
  80. if z.type() == 'torch.cuda.HalfTensor':
  81. W_inverse = W_inverse.half()
  82. self.W_inverse = W_inverse
  83. z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
  84. return z
  85. else:
  86. # Forward computation
  87. log_det_W = batch_size * n_of_groups * torch.logdet(W)
  88. z = self.conv(z)
  89. return z, log_det_W
  90. class WN(torch.nn.Module):
  91. """
  92. This is the WaveNet like layer for the affine coupling. The primary difference
  93. from WaveNet is the convolutions need not be causal. There is also no dilation
  94. size reset. The dilation only doubles on each layer
  95. """
  96. def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
  97. kernel_size):
  98. super(WN, self).__init__()
  99. assert(kernel_size % 2 == 1)
  100. assert(n_channels % 2 == 0)
  101. self.n_layers = n_layers
  102. self.n_channels = n_channels
  103. self.in_layers = torch.nn.ModuleList()
  104. self.res_skip_layers = torch.nn.ModuleList()
  105. start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
  106. start = torch.nn.utils.weight_norm(start, name='weight')
  107. self.start = start
  108. # Initializing last layer to 0 makes the affine coupling layers
  109. # do nothing at first. This helps with training stability
  110. end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
  111. end.weight.data.zero_()
  112. end.bias.data.zero_()
  113. self.end = end
  114. cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
  115. self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
  116. for i in range(n_layers):
  117. dilation = 2 ** i
  118. padding = int((kernel_size*dilation - dilation)/2)
  119. in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
  120. dilation=dilation, padding=padding)
  121. in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
  122. self.in_layers.append(in_layer)
  123. # last one is not necessary
  124. if i < n_layers - 1:
  125. res_skip_channels = 2*n_channels
  126. else:
  127. res_skip_channels = n_channels
  128. res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
  129. res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
  130. self.res_skip_layers.append(res_skip_layer)
  131. def forward(self, forward_input):
  132. audio, spect = forward_input
  133. audio = self.start(audio)
  134. output = torch.zeros_like(audio)
  135. n_channels_tensor = torch.IntTensor([self.n_channels])
  136. spect = self.cond_layer(spect)
  137. for i in range(self.n_layers):
  138. spect_offset = i*2*self.n_channels
  139. acts = fused_add_tanh_sigmoid_multiply(
  140. self.in_layers[i](audio),
  141. spect[:,spect_offset:spect_offset+2*self.n_channels,:],
  142. n_channels_tensor)
  143. res_skip_acts = self.res_skip_layers[i](acts)
  144. if i < self.n_layers - 1:
  145. audio = audio + res_skip_acts[:,:self.n_channels,:]
  146. output = output + res_skip_acts[:,self.n_channels:,:]
  147. else:
  148. output = output + res_skip_acts
  149. return self.end(output)
  150. class WaveGlow(torch.nn.Module):
  151. def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
  152. n_early_size, WN_config):
  153. super(WaveGlow, self).__init__()
  154. self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
  155. n_mel_channels,
  156. 1024, stride=256)
  157. assert(n_group % 2 == 0)
  158. self.n_flows = n_flows
  159. self.n_group = n_group
  160. self.n_early_every = n_early_every
  161. self.n_early_size = n_early_size
  162. self.WN = torch.nn.ModuleList()
  163. self.convinv = torch.nn.ModuleList()
  164. n_half = int(n_group/2)
  165. # Set up layers with the right sizes based on how many dimensions
  166. # have been output already
  167. n_remaining_channels = n_group
  168. for k in range(n_flows):
  169. if k % self.n_early_every == 0 and k > 0:
  170. n_half = n_half - int(self.n_early_size/2)
  171. n_remaining_channels = n_remaining_channels - self.n_early_size
  172. self.convinv.append(Invertible1x1Conv(n_remaining_channels))
  173. self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
  174. self.n_remaining_channels = n_remaining_channels # Useful during inference
  175. def forward(self, forward_input):
  176. """
  177. forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
  178. forward_input[1] = audio: batch x time
  179. """
  180. spect, audio = forward_input
  181. # Upsample spectrogram to size of audio
  182. spect = self.upsample(spect)
  183. assert(spect.size(2) >= audio.size(1))
  184. if spect.size(2) > audio.size(1):
  185. spect = spect[:, :, :audio.size(1)]
  186. spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
  187. spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
  188. audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
  189. output_audio = []
  190. log_s_list = []
  191. log_det_W_list = []
  192. for k in range(self.n_flows):
  193. if k % self.n_early_every == 0 and k > 0:
  194. output_audio.append(audio[:,:self.n_early_size,:])
  195. audio = audio[:,self.n_early_size:,:]
  196. audio, log_det_W = self.convinv[k](audio)
  197. log_det_W_list.append(log_det_W)
  198. n_half = int(audio.size(1)/2)
  199. audio_0 = audio[:,:n_half,:]
  200. audio_1 = audio[:,n_half:,:]
  201. output = self.WN[k]((audio_0, spect))
  202. log_s = output[:, n_half:, :]
  203. b = output[:, :n_half, :]
  204. audio_1 = torch.exp(log_s)*audio_1 + b
  205. log_s_list.append(log_s)
  206. audio = torch.cat([audio_0, audio_1],1)
  207. output_audio.append(audio)
  208. return torch.cat(output_audio,1), log_s_list, log_det_W_list
  209. def infer(self, spect, sigma=1.0):
  210. spect = self.upsample(spect)
  211. # trim conv artifacts. maybe pad spec to kernel multiple
  212. time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
  213. spect = spect[:, :, :-time_cutoff]
  214. spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
  215. spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
  216. if spect.type() == 'torch.cuda.HalfTensor':
  217. audio = torch.cuda.HalfTensor(spect.size(0),
  218. self.n_remaining_channels,
  219. spect.size(2)).normal_()
  220. else:
  221. # audio = torch.cuda.FloatTensor(spect.size(0),
  222. # self.n_remaining_channels,
  223. # spect.size(2)).normal_()
  224. audio = torch.FloatTensor(spect.size(0),
  225. self.n_remaining_channels,
  226. spect.size(2)).normal_()
  227. audio = torch.autograd.Variable(sigma*audio)
  228. for k in reversed(range(self.n_flows)):
  229. n_half = int(audio.size(1)/2)
  230. audio_0 = audio[:,:n_half,:]
  231. audio_1 = audio[:,n_half:,:]
  232. output = self.WN[k]((audio_0, spect))
  233. s = output[:, n_half:, :]
  234. b = output[:, :n_half, :]
  235. audio_1 = (audio_1 - b)/torch.exp(s)
  236. audio = torch.cat([audio_0, audio_1],1)
  237. audio = self.convinv[k](audio, reverse=True)
  238. if k % self.n_early_every == 0 and k > 0:
  239. if spect.type() == 'torch.cuda.HalfTensor':
  240. z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
  241. else:
  242. # z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
  243. z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
  244. audio = torch.cat((sigma*z, audio),1)
  245. audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
  246. return audio
  247. @staticmethod
  248. def remove_weightnorm(model):
  249. waveglow = model
  250. for WN in waveglow.WN:
  251. WN.start = torch.nn.utils.remove_weight_norm(WN.start)
  252. WN.in_layers = remove(WN.in_layers)
  253. WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
  254. WN.res_skip_layers = remove(WN.res_skip_layers)
  255. return waveglow
  256. def remove(conv_list):
  257. new_conv_list = torch.nn.ModuleList()
  258. for old_conv in conv_list:
  259. old_conv = torch.nn.utils.remove_weight_norm(old_conv)
  260. new_conv_list.append(old_conv)
  261. return new_conv_list