# model.py
from math import sqrt
import numpy as np
from numpy import finfo
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths
from modules import GST

drop_rate = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(hparams):
    # model = Tacotron2(hparams).cuda()
    model = Tacotron2(hparams).to(device)
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min
    return model
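# A minimal usage sketch for load_model(), commented out so the module stays
# import-only. Assumption: `create_hparams` comes from an accompanying hparams.py
# and provides at least `fp16_run`; it is not defined in this file.
#
#     from hparams import create_hparams
#     hparams = create_hparams()
#     model = load_model(hparams)
#     model.eval()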
class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))
        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, attention_weights=None):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        if attention_weights is None:
            alignment = self.get_alignment_energies(
                attention_hidden_state, processed_memory, attention_weights_cat)
            if mask is not None:
                alignment.data.masked_fill_(mask, self.score_mask_value)
            attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
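# The energies computed in Attention.get_alignment_energies above implement
# location-sensitive additive attention, roughly
#     e_i = v^T tanh(query_layer(query) + memory_layer(memory_i) + location_layer(f)_i)
# where f = location_conv(attention_weights_cat) summarizes the previous and
# cumulative attention weights.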
class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            # note: dropout stays active even at inference time (training=True)
            x = F.dropout(F.relu(linear(x)), p=drop_rate, training=True)
        return x
class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """
    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), drop_rate, self.training)
        x = F.dropout(self.convolutions[-1](x), drop_rate, self.training)
        return x
class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        if x.size()[0] > 1:
            x_embedded = []
            for b_ind in range(x.size()[0]):  # TODO: Speed up
                curr_x = x[b_ind:b_ind+1, :, :input_lengths[b_ind]].clone()
                for conv in self.convolutions:
                    curr_x = F.dropout(F.relu(conv(curr_x)), drop_rate, self.training)
                x_embedded.append(curr_x[0].transpose(0, 1))
            x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
        else:
            for conv in self.convolutions:
                x = F.dropout(F.relu(conv(x)), drop_rate, self.training)
            x = x.transpose(1, 2)

        # PyTorch tensors are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), drop_rate, self.training)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
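# Encoder.inference skips the length-based packing used in forward(): a single,
# unpadded input sequence is assumed at inference time, so no masking is needed.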
class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = (hparams.encoder_embedding_dim +
                                      hparams.token_embedding_size +
                                      hparams.speaker_embedding_dim)
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout
        self.p_teacher_forcing = hparams.p_teacher_forcing

        self.prenet_f0 = ConvNorm(
            1, hparams.prenet_f0_dim,
            kernel_size=hparams.prenet_f0_kernel_size,
            padding=max(0, int(hparams.prenet_f0_kernel_size/2)),
            bias=False, stride=1, dilation=1)

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.prenet_f0_dim + self.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, self.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + self.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + self.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + self.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')
    def get_go_frame(self, memory):
        """ Gets an all-zeros frame to use as the first decoder input
        PARAMS
        ------
        memory: encoder outputs
        RETURNS
        -------
        decoder_input: all-zeros frame
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def get_end_f0(self, f0s):
        B = f0s.size(0)
        dummy = Variable(f0s.data.new(B, 1, f0s.size(1)).zero_())
        return dummy
    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask
    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs
    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        RETURNS
        -------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs)
        if len(gate_outputs.size()) > 1:
            gate_outputs = gate_outputs.transpose(0, 1)
        else:
            gate_outputs = gate_outputs[None]
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments
    def decode(self, decoder_input, attention_weights=None):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output
        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask, attention_weights)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(
            self.decoder_cell, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)

        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights
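    # Data flow of one decode() step, as implemented above: decoder_input (the
    # prenet output concatenated with an F0 feature by the caller) is joined
    # with the current attention context and fed to the attention LSTM cell;
    # its hidden state queries the location-sensitive attention to produce a
    # new context; the attention hidden state plus context drive the decoder
    # LSTM cell, whose output (again with the context) is projected to
    # n_mel_channels * n_frames_per_step mel values and a scalar stop-gate energy.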
    def forward(self, memory, decoder_inputs, memory_lengths, f0s):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        # audio features
        f0_dummy = self.get_end_f0(f0s)
        f0s = torch.cat((f0s, f0_dummy), dim=2)
        f0s = F.relu(self.prenet_f0(f0s))
        f0s = f0s.permute(2, 0, 1)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            if len(mel_outputs) == 0 or np.random.uniform(0.0, 1.0) <= self.p_teacher_forcing:
                decoder_input = torch.cat((decoder_inputs[len(mel_outputs)],
                                           f0s[len(mel_outputs)]), dim=1)
            else:
                decoder_input = torch.cat((self.prenet(mel_outputs[-1]),
                                           f0s[len(mel_outputs)]), dim=1)
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
    def inference(self, memory, f0s):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)
        f0_dummy = self.get_end_f0(f0s)
        f0s = torch.cat((f0s, f0_dummy), dim=2)
        f0s = F.relu(self.prenet_f0(f0s))
        f0s = f0s.permute(2, 0, 1)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            if len(mel_outputs) < len(f0s):
                f0 = f0s[len(mel_outputs)]
            else:
                f0 = f0s[-1] * 0

            decoder_input = torch.cat((self.prenet(decoder_input), f0), dim=1)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
    def inference_noattention(self, memory, f0s, attention_map):
        """ Decoder inference using an externally supplied attention map
        PARAMS
        ------
        memory: Encoder outputs
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)
        f0_dummy = self.get_end_f0(f0s)
        f0s = torch.cat((f0s, f0_dummy), dim=2)
        f0s = F.relu(self.prenet_f0(f0s))
        f0s = f0s.permute(2, 0, 1)

        mel_outputs, gate_outputs, alignments = [], [], []
        for i in range(len(attention_map)):
            f0 = f0s[i]
            attention = attention_map[i]
            decoder_input = torch.cat((self.prenet(decoder_input), f0), dim=1)
            mel_output, gate_output, alignment = self.decode(decoder_input, attention)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)
        if hparams.with_gst:
            self.gst = GST(hparams)
        self.speaker_embedding = nn.Embedding(
            hparams.n_speakers, hparams.speaker_embedding_dim)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths, speaker_ids, f0_padded = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        speaker_ids = to_gpu(speaker_ids.data).long()
        f0_padded = to_gpu(f0_padded).float()
        return ((text_padded, input_lengths, mel_padded, max_len,
                 output_lengths, speaker_ids, f0_padded),
                (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
    def forward(self, inputs):
        inputs, input_lengths, targets, max_len, \
            output_lengths, speaker_ids, f0s = inputs
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        embedded_text = self.encoder(embedded_inputs, input_lengths)
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        embedded_gst = self.gst(targets, output_lengths)
        embedded_gst = embedded_gst.repeat(1, embedded_text.size(1), 1)
        embedded_speakers = embedded_speakers.repeat(1, embedded_text.size(1), 1)

        encoder_outputs = torch.cat(
            (embedded_text, embedded_gst, embedded_speakers), dim=2)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths, f0s=f0s)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)
    def inference(self, inputs):
        text, style_input, speaker_ids, f0s = inputs
        embedded_inputs = self.embedding(text).transpose(1, 2)
        embedded_text = self.encoder.inference(embedded_inputs)
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        if hasattr(self, 'gst'):
            if isinstance(style_input, int):
                query = torch.zeros(1, 1, self.gst.encoder.ref_enc_gru_size).to(device)
                GST = torch.tanh(self.gst.stl.embed)
                key = GST[style_input].unsqueeze(0).expand(1, -1, -1)
                embedded_gst = self.gst.stl.attention(query, key)
            else:
                embedded_gst = self.gst(style_input)

        embedded_speakers = embedded_speakers.repeat(1, embedded_text.size(1), 1)
        if hasattr(self, 'gst'):
            embedded_gst = embedded_gst.repeat(1, embedded_text.size(1), 1)
            encoder_outputs = torch.cat(
                (embedded_text, embedded_gst, embedded_speakers), dim=2)
        else:
            encoder_outputs = torch.cat(
                (embedded_text, embedded_speakers), dim=2)

        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs, f0s)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
    def inference_noattention(self, inputs):
        text, style_input, speaker_ids, f0s, attention_map = inputs
        embedded_inputs = self.embedding(text).transpose(1, 2)
        embedded_text = self.encoder.inference(embedded_inputs)
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        if hasattr(self, 'gst'):
            if isinstance(style_input, int):
                query = torch.zeros(1, 1, self.gst.encoder.ref_enc_gru_size).to(device)
                GST = torch.tanh(self.gst.stl.embed)
                key = GST[style_input].unsqueeze(0).expand(1, -1, -1)
                embedded_gst = self.gst.stl.attention(query, key)
            else:
                embedded_gst = self.gst(style_input)

        embedded_speakers = embedded_speakers.repeat(1, embedded_text.size(1), 1)
        if hasattr(self, 'gst'):
            embedded_gst = embedded_gst.repeat(1, embedded_text.size(1), 1)
            encoder_outputs = torch.cat(
                (embedded_text, embedded_gst, embedded_speakers), dim=2)
        else:
            encoder_outputs = torch.cat(
                (embedded_text, embedded_speakers), dim=2)

        mel_outputs, gate_outputs, alignments = self.decoder.inference_noattention(
            encoder_outputs, f0s, attention_map)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
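# A minimal end-to-end inference sketch, commented out so the module stays
# import-only. Assumptions: `create_hparams` and `text_to_sequence` come from the
# surrounding codebase (hparams.py and a text/ package), a trained checkpoint
# exists on disk, and `f0s` is a (1, 1, T) pitch-contour tensor; none of these are
# defined in this file.
#
#     from hparams import create_hparams
#     from text import text_to_sequence
#
#     hparams = create_hparams()
#     model = load_model(hparams)
#     model.load_state_dict(
#         torch.load("checkpoint.pt", map_location=device)['state_dict'])
#     model.eval()
#
#     sequence = torch.LongTensor(
#         text_to_sequence("Hello world.", ['english_cleaners']))[None].to(device)
#     speaker_id = torch.LongTensor([0]).to(device)
#     style_input = 0  # index of a global style token
#     with torch.no_grad():
#         mel, mel_postnet, gate, alignments = model.inference(
#             (sequence, style_input, speaker_id, f0s))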