glow_old.py

import copy
import torch
from glow import Invertible1x1Conv, remove


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Fused gated activation: the first n_channels channels go through tanh,
    # the remaining channels through sigmoid, and the two halves are multiplied.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
class WN(torch.nn.Module):
    """
    This is the WaveNet-like layer for the affine coupling. The primary
    difference from WaveNet is that the convolutions need not be causal.
    There is also no dilation size reset; the dilation simply doubles on
    each layer.
    """
    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
                 kernel_size):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        assert(n_channels % 2 == 0)
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.cond_layers = torch.nn.ModuleList()

        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start

        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability.
        end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2 ** i
            padding = int((kernel_size*dilation - dilation)/2)
            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
            self.cond_layers.append(cond_layer)

            # The last layer only needs the skip output, not the residual half
            if i < n_layers - 1:
                res_skip_channels = 2*n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
                                                        name='weight')
            self.res_skip_layers.append(res_skip_layer)
    def forward(self, forward_input):
        audio, spect = forward_input
        audio = self.start(audio)

        for i in range(self.n_layers):
            acts = fused_add_tanh_sigmoid_multiply(
                self.in_layers[i](audio),
                self.cond_layers[i](spect),
                torch.IntTensor([self.n_channels]))

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                audio = res_skip_acts[:, :self.n_channels, :] + audio
                skip_acts = res_skip_acts[:, self.n_channels:, :]
            else:
                skip_acts = res_skip_acts

            if i == 0:
                output = skip_acts
            else:
                output = skip_acts + output
        return self.end(output)
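
# Shape note (added for clarity): WN.forward takes a tuple (audio, spect) with
#   audio: batch x n_in_channels x time
#   spect: batch x n_mel_channels x time
# and returns batch x 2*n_in_channels x time, which the caller splits into the
# log-scale s and bias b of the affine coupling. With, e.g., kernel_size=3
# (an illustrative value), layer i uses dilation 2**i and padding 2**i, so
# every in_layer preserves the time dimension.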
class WaveGlow(torch.nn.Module):
    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
                 n_early_size, WN_config):
        super(WaveGlow, self).__init__()

        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
                                                 n_mel_channels,
                                                 1024, stride=256)
        assert(n_group % 2 == 0)
        self.n_flows = n_flows
        self.n_group = n_group
        self.n_early_every = n_early_every
        self.n_early_size = n_early_size
        self.WN = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()

        n_half = int(n_group/2)

        # Set up layers with the right sizes based on how many dimensions
        # have been output already
        n_remaining_channels = n_group
        for k in range(n_flows):
            if k % self.n_early_every == 0 and k > 0:
                n_half = n_half - int(self.n_early_size/2)
                n_remaining_channels = n_remaining_channels - self.n_early_size
            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
            self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
        self.n_remaining_channels = n_remaining_channels  # Useful during inference
    def forward(self, forward_input):
        return None
        """
        forward_input[0] = audio: batch x time
        forward_input[1] = upsamp_spectrogram: batch x n_cond_channels x time
        """
        """
        spect, audio = forward_input

        # Upsample spectrogram to size of audio
        spect = self.upsample(spect)
        assert(spect.size(2) >= audio.size(1))
        if spect.size(2) > audio.size(1):
            spect = spect[:, :, :audio.size(1)]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
        output_audio = []
        s_list = []
        s_conv_list = []

        for k in range(self.n_flows):
            if k % 4 == 0 and k > 0:
                output_audio.append(audio[:, :self.n_multi, :])
                audio = audio[:, self.n_multi:, :]

            # project to new basis
            audio, s = self.convinv[k](audio)
            s_conv_list.append(s)

            n_half = int(audio.size(1)/2)
            if k % 2 == 0:
                audio_0 = audio[:, :n_half, :]
                audio_1 = audio[:, n_half:, :]
            else:
                audio_1 = audio[:, :n_half, :]
                audio_0 = audio[:, n_half:, :]

            output = self.nn[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = torch.exp(s)*audio_1 + b
            s_list.append(s)

            if k % 2 == 0:
                audio = torch.cat([audio[:, :n_half, :], audio_1], 1)
            else:
                audio = torch.cat([audio_1, audio[:, n_half:, :]], 1)
        output_audio.append(audio)
        return torch.cat(output_audio, 1), s_list, s_conv_list
        """
    def infer(self, spect, sigma=1.0):
        spect = self.upsample(spect)
        # trim conv artifacts. maybe pad spec to kernel multiple
        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
        spect = spect[:, :, :-time_cutoff]

        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)

        if spect.type() == 'torch.cuda.HalfTensor':
            audio = torch.cuda.HalfTensor(spect.size(0),
                                          self.n_remaining_channels,
                                          spect.size(2)).normal_()
        else:
            audio = torch.cuda.FloatTensor(spect.size(0),
                                           self.n_remaining_channels,
                                           spect.size(2)).normal_()

        audio = torch.autograd.Variable(sigma*audio)

        for k in reversed(range(self.n_flows)):
            n_half = int(audio.size(1)/2)
            if k % 2 == 0:
                audio_0 = audio[:, :n_half, :]
                audio_1 = audio[:, n_half:, :]
            else:
                audio_1 = audio[:, :n_half, :]
                audio_0 = audio[:, n_half:, :]

            output = self.WN[k]((audio_0, spect))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b)/torch.exp(s)
            if k % 2 == 0:
                audio = torch.cat([audio[:, :n_half, :], audio_1], 1)
            else:
                audio = torch.cat([audio_1, audio[:, n_half:, :]], 1)

            audio = self.convinv[k](audio, reverse=True)

            if k % 4 == 0 and k > 0:
                if spect.type() == 'torch.cuda.HalfTensor':
                    z = torch.cuda.HalfTensor(spect.size(0),
                                              self.n_early_size,
                                              spect.size(2)).normal_()
                else:
                    z = torch.cuda.FloatTensor(spect.size(0),
                                               self.n_early_size,
                                               spect.size(2)).normal_()
                audio = torch.cat((sigma*z, audio), 1)

        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
    @staticmethod
    def remove_weightnorm(model):
        waveglow = model
        for WN in waveglow.WN:
            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
            WN.in_layers = remove(WN.in_layers)
            WN.cond_layers = remove(WN.cond_layers)
            WN.res_skip_layers = remove(WN.res_skip_layers)
        return waveglow
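

# The block below is a minimal smoke-test sketch, not part of the original
# module: the hyperparameters (n_flows, n_group, WN_config, sigma, the 63-frame
# dummy spectrogram) are illustrative assumptions, not values from any released
# WaveGlow configuration. It assumes a CUDA device, since infer() allocates
# torch.cuda tensors, and it only checks shapes; an untrained model will not
# produce meaningful audio.
if __name__ == '__main__':
    WN_config = {'n_layers': 4, 'n_channels': 64, 'kernel_size': 3}
    model = WaveGlow(n_mel_channels=80, n_flows=4, n_group=8,
                     n_early_every=4, n_early_size=2,
                     WN_config=WN_config).cuda()
    model = WaveGlow.remove_weightnorm(model)

    spect = torch.randn(1, 80, 63).cuda()  # batch x n_mel_channels x frames
    with torch.no_grad():
        audio = model.infer(spect, sigma=0.6)
    print(audio.shape)  # expected (1, 63*256) = (1, 16128) after upsampling and trimming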