modules.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805
  1. import math
  2. from dataclasses import dataclass
  3. import torch
  4. from encodec.quantization.core_vq import VectorQuantization
  5. from torch import nn
  6. from torch.nn import Conv1d, Conv2d, ConvTranspose1d
  7. from torch.nn import functional as F
  8. from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
  9. from fish_speech.models.vqgan.utils import convert_pad_shape, get_padding, init_weights
  10. LRELU_SLOPE = 0.1
@dataclass
class VQEncoderOutput:
    """Container for the VQEncoder forward pass results."""

    # Commitment/codebook loss returned by the vector quantizer.
    loss: torch.Tensor
    # Encoded features, transposed to (batch, channels, seq_len) before return.
    features: torch.Tensor
  15. class VQEncoder(nn.Module):
  16. def __init__(
  17. self,
  18. in_channels: int = 1024,
  19. channels: int = 192,
  20. num_mels: int = 128,
  21. num_heads: int = 2,
  22. num_feature_layers: int = 2,
  23. num_speaker_layers: int = 4,
  24. num_mixin_layers: int = 4,
  25. input_downsample: bool = True,
  26. code_book_size: int = 2048,
  27. freeze_vq: bool = False,
  28. ):
  29. super().__init__()
  30. # Feature Encoder
  31. down_sample = 2 if input_downsample else 1
  32. self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
  33. self.vq = VectorQuantization(
  34. dim=in_channels,
  35. codebook_size=code_book_size,
  36. threshold_ema_dead_code=2,
  37. kmeans_init=True,
  38. kmeans_iters=50,
  39. )
  40. self.feature_in = nn.Linear(in_channels, channels)
  41. self.feature_blocks = nn.ModuleList(
  42. [
  43. TransformerBlock(
  44. channels,
  45. num_heads,
  46. window_size=4,
  47. window_heads_share=True,
  48. proximal_init=True,
  49. proximal_bias=False,
  50. use_relative_attn=True,
  51. )
  52. for _ in range(num_feature_layers)
  53. ]
  54. )
  55. # Speaker Encoder
  56. self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
  57. self.speaker_in = nn.Linear(num_mels, channels)
  58. self.speaker_blocks = nn.ModuleList(
  59. [
  60. TransformerBlock(
  61. channels,
  62. num_heads,
  63. use_relative_attn=False,
  64. )
  65. for _ in range(num_speaker_layers)
  66. ]
  67. )
  68. # Final Mixer
  69. self.mixer_in = nn.ModuleList(
  70. [
  71. TransformerBlock(
  72. channels,
  73. num_heads,
  74. window_size=4,
  75. window_heads_share=True,
  76. proximal_init=True,
  77. proximal_bias=False,
  78. use_relative_attn=True,
  79. )
  80. for _ in range(num_mixin_layers)
  81. ]
  82. )
  83. self.input_downsample = input_downsample
  84. if freeze_vq:
  85. for p in self.vq.parameters():
  86. p.requires_grad = False
  87. for p in self.vq_in.parameters():
  88. p.requires_grad = False
  89. def forward(self, x, mels, key_padding_mask=None):
  90. # x: (batch, seq_len, channels)
  91. # x: (batch, seq_len, 128)
  92. if self.input_downsample and key_padding_mask is not None:
  93. key_padding_mask = key_padding_mask[:, ::2]
  94. # Merge Channels
  95. if self.input_downsample:
  96. feature_0, feature_1 = x[:, ::2], x[:, 1::2]
  97. min_len = min(feature_0.size(1), feature_1.size(1))
  98. x = torch.cat([feature_0[:, :min_len], feature_1[:, :min_len]], dim=2)
  99. # Encode Features
  100. features = self.vq_in(x)
  101. assert key_padding_mask.size(1) == features.size(
  102. 1
  103. ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
  104. features, _, loss = self.vq(features, mask=~key_padding_mask)
  105. if self.input_downsample:
  106. features = F.interpolate(
  107. features.transpose(1, 2), scale_factor=2
  108. ).transpose(1, 2)
  109. features = self.feature_in(features)
  110. for block in self.feature_blocks:
  111. features = block(features, key_padding_mask=key_padding_mask)
  112. # Encode Speaker
  113. speaker = self.speaker_in(x)
  114. speaker = torch.cat(
  115. [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
  116. )
  117. for block in self.speaker_blocks:
  118. speaker = block(mels, key_padding_mask=key_padding_mask)
  119. # Mix
  120. x = features + speaker[:, :1]
  121. for block in self.mixer_in:
  122. x = block(x, key_padding_mask=key_padding_mask)
  123. return VQEncoderOutput(
  124. loss=loss,
  125. features=x.transpose(1, 2),
  126. )
  127. class TransformerBlock(nn.Module):
  128. def __init__(
  129. self,
  130. channels,
  131. n_heads,
  132. mlp_ratio=4 * 2 / 3,
  133. p_dropout=0.0,
  134. window_size=4,
  135. window_heads_share=True,
  136. proximal_init=True,
  137. proximal_bias=False,
  138. use_relative_attn=True,
  139. ):
  140. super().__init__()
  141. self.attn_norm = RMSNorm(channels)
  142. if use_relative_attn:
  143. self.attn = RelativeAttention(
  144. channels,
  145. n_heads,
  146. p_dropout,
  147. window_size,
  148. window_heads_share,
  149. proximal_init,
  150. proximal_bias,
  151. )
  152. else:
  153. self.attn = nn.MultiheadAttention(
  154. embed_dim=channels,
  155. num_heads=n_heads,
  156. dropout=p_dropout,
  157. batch_first=True,
  158. )
  159. self.mlp_norm = RMSNorm(channels)
  160. self.mlp = SwiGLU(channels, int(channels * mlp_ratio), channels, drop=p_dropout)
  161. def forward(self, x, key_padding_mask=None):
  162. norm_x = self.attn_norm(x)
  163. if isinstance(self.attn, RelativeAttention):
  164. attn = self.attn(norm_x, key_padding_mask=key_padding_mask)
  165. else:
  166. attn, _ = self.attn(
  167. norm_x, norm_x, norm_x, key_padding_mask=key_padding_mask
  168. )
  169. x = x + attn
  170. x = x + self.mlp(self.mlp_norm(x))
  171. return x
  172. class SwiGLU(nn.Module):
  173. """
  174. Swish-Gated Linear Unit (SwiGLU) activation function
  175. """
  176. def __init__(
  177. self,
  178. in_features,
  179. hidden_features=None,
  180. out_features=None,
  181. bias=True,
  182. drop=0.0,
  183. ):
  184. super().__init__()
  185. out_features = out_features or in_features
  186. hidden_features = hidden_features or in_features
  187. assert hidden_features % 2 == 0
  188. self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
  189. self.act = nn.SiLU()
  190. self.drop1 = nn.Dropout(drop)
  191. self.norm = RMSNorm(hidden_features // 2)
  192. self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias)
  193. self.drop2 = nn.Dropout(drop)
  194. def init_weights(self):
  195. # override init of fc1 w/ gate portion set to weight near zero, bias=1
  196. fc1_mid = self.fc1.bias.shape[0] // 2
  197. nn.init.ones_(self.fc1.bias[fc1_mid:])
  198. nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
  199. def forward(self, x):
  200. x = self.fc1(x)
  201. x1, x2 = x.chunk(2, dim=-1)
  202. x = x1 * self.act(x2)
  203. x = self.drop1(x)
  204. x = self.norm(x)
  205. x = self.fc2(x)
  206. x = self.drop2(x)
  207. return x
  208. class RMSNorm(nn.Module):
  209. def __init__(self, hidden_size, eps=1e-6):
  210. """
  211. LlamaRMSNorm is equivalent to T5LayerNorm
  212. """
  213. super().__init__()
  214. self.weight = nn.Parameter(torch.ones(hidden_size))
  215. self.variance_epsilon = eps
  216. def forward(self, hidden_states):
  217. input_dtype = hidden_states.dtype
  218. hidden_states = hidden_states.to(torch.float32)
  219. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  220. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  221. return self.weight * hidden_states.to(input_dtype)
class RelativeAttention(nn.Module):
    """Multi-head self-attention with windowed relative position embeddings
    and an optional proximal bias.

    Args:
        channels: Model width; must divide evenly by n_heads.
        n_heads: Number of attention heads.
        p_dropout: Dropout applied to the attention weights.
        window_size: Half-width of the relative-position window; embeddings
            cover offsets in [-window_size, window_size]. None disables them.
        window_heads_share: If True, one set of relative embeddings is shared
            across all heads.
        proximal_init: If True, key weights are copied onto query weights so
            q and k start identical.
        proximal_bias: If True, a distance-penalizing bias is added to scores.
    """

    def __init__(
        self,
        channels,
        n_heads,
        p_dropout=0.0,
        window_size=4,
        window_heads_share=True,
        proximal_init=True,
        proximal_bias=False,
    ):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = window_heads_share
        self.proximal_init = proximal_init
        self.proximal_bias = proximal_bias
        self.k_channels = channels // n_heads
        # Fused query/key/value projection: rows [0:C] = q, [C:2C] = k, [2C:3C] = v.
        self.qkv = nn.Linear(channels, channels * 3)
        self.drop = nn.Dropout(p_dropout)
        if window_size is not None:
            n_heads_rel = 1 if window_heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            # One embedding per relative offset in [-window_size, window_size].
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
        nn.init.xavier_uniform_(self.qkv.weight)
        if proximal_init:
            with torch.no_grad():
                # Sync qk weights
                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
                    self.channels : self.channels * 2
                ]
                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
                    self.channels : self.channels * 2
                ]

    def forward(self, x, key_padding_mask=None):
        """Self-attend over x.

        x: (batch, seq_len, channels)
        key_padding_mask: optional (batch, seq_len) bool, True = masked out.
        Returns: (batch, seq_len, channels)
        """
        # x: (batch, seq_len, channels)
        batch_size, seq_len, _ = x.size()
        # Project and reshape to (3, batch, heads, seq_len, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
            .permute(2, 0, 3, 1, 4)
        )
        query, key, value = torch.unbind(qkv, dim=0)

        # Scaled dot-product content scores: (b, h, t, t).
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            # Add relative-position logits, converted to absolute indexing.
            key_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_k, seq_len
            )
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local

        if self.proximal_bias:
            # Penalize attention to distant positions by -log1p(|distance|).
            scores = scores + self._attention_bias_proximal(seq_len).to(
                device=scores.device, dtype=scores.dtype
            )

        # key_padding_mask: (batch, seq_len)
        if key_padding_mask is not None:
            assert key_padding_mask.size() == (
                batch_size,
                seq_len,
            ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
            assert (
                key_padding_mask.dtype == torch.bool
            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
            # Broadcast the mask over heads and query positions.
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.n_heads, -1, -1
            )
            scores = scores.masked_fill(key_padding_mask, float("-inf"))

        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            # Add the relative-value contribution using the same weights.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, seq_len
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )

        # Merge heads back to (batch, seq_len, channels).
        return output.reshape(batch_size, seq_len, self.n_heads * self.k_channels)

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad/slice the learned window embeddings to 2*length-1 offsets."""
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
  386. class ResBlock1(torch.nn.Module):
  387. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  388. super(ResBlock1, self).__init__()
  389. self.convs1 = nn.ModuleList(
  390. [
  391. weight_norm(
  392. Conv1d(
  393. channels,
  394. channels,
  395. kernel_size,
  396. 1,
  397. dilation=dilation[0],
  398. padding=get_padding(kernel_size, dilation[0]),
  399. )
  400. ),
  401. weight_norm(
  402. Conv1d(
  403. channels,
  404. channels,
  405. kernel_size,
  406. 1,
  407. dilation=dilation[1],
  408. padding=get_padding(kernel_size, dilation[1]),
  409. )
  410. ),
  411. weight_norm(
  412. Conv1d(
  413. channels,
  414. channels,
  415. kernel_size,
  416. 1,
  417. dilation=dilation[2],
  418. padding=get_padding(kernel_size, dilation[2]),
  419. )
  420. ),
  421. ]
  422. )
  423. self.convs1.apply(init_weights)
  424. self.convs2 = nn.ModuleList(
  425. [
  426. weight_norm(
  427. Conv1d(
  428. channels,
  429. channels,
  430. kernel_size,
  431. 1,
  432. dilation=1,
  433. padding=get_padding(kernel_size, 1),
  434. )
  435. ),
  436. weight_norm(
  437. Conv1d(
  438. channels,
  439. channels,
  440. kernel_size,
  441. 1,
  442. dilation=1,
  443. padding=get_padding(kernel_size, 1),
  444. )
  445. ),
  446. weight_norm(
  447. Conv1d(
  448. channels,
  449. channels,
  450. kernel_size,
  451. 1,
  452. dilation=1,
  453. padding=get_padding(kernel_size, 1),
  454. )
  455. ),
  456. ]
  457. )
  458. self.convs2.apply(init_weights)
  459. def forward(self, x, x_mask=None):
  460. for c1, c2 in zip(self.convs1, self.convs2):
  461. xt = F.leaky_relu(x, LRELU_SLOPE)
  462. if x_mask is not None:
  463. xt = xt * x_mask
  464. xt = c1(xt)
  465. xt = F.leaky_relu(xt, LRELU_SLOPE)
  466. if x_mask is not None:
  467. xt = xt * x_mask
  468. xt = c2(xt)
  469. x = xt + x
  470. if x_mask is not None:
  471. x = x * x_mask
  472. return x
  473. def remove_weight_norm(self):
  474. for l in self.convs1:
  475. remove_weight_norm(l)
  476. for l in self.convs2:
  477. remove_weight_norm(l)
  478. class ResBlock2(torch.nn.Module):
  479. def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
  480. super(ResBlock2, self).__init__()
  481. self.convs = nn.ModuleList(
  482. [
  483. weight_norm(
  484. Conv1d(
  485. channels,
  486. channels,
  487. kernel_size,
  488. 1,
  489. dilation=dilation[0],
  490. padding=get_padding(kernel_size, dilation[0]),
  491. )
  492. ),
  493. weight_norm(
  494. Conv1d(
  495. channels,
  496. channels,
  497. kernel_size,
  498. 1,
  499. dilation=dilation[1],
  500. padding=get_padding(kernel_size, dilation[1]),
  501. )
  502. ),
  503. ]
  504. )
  505. self.convs.apply(init_weights)
  506. def forward(self, x, x_mask=None):
  507. for c in self.convs:
  508. xt = F.leaky_relu(x, LRELU_SLOPE)
  509. if x_mask is not None:
  510. xt = xt * x_mask
  511. xt = c(xt)
  512. x = xt + x
  513. if x_mask is not None:
  514. x = x * x_mask
  515. return x
  516. def remove_weight_norm(self):
  517. for l in self.convs:
  518. remove_weight_norm(l)
  519. class Generator(nn.Module):
  520. def __init__(
  521. self,
  522. initial_channel,
  523. resblock,
  524. resblock_kernel_sizes,
  525. resblock_dilation_sizes,
  526. upsample_rates,
  527. upsample_initial_channel,
  528. upsample_kernel_sizes,
  529. ):
  530. super(Generator, self).__init__()
  531. self.num_kernels = len(resblock_kernel_sizes)
  532. self.num_upsamples = len(upsample_rates)
  533. self.conv_pre = Conv1d(
  534. initial_channel, upsample_initial_channel, 7, 1, padding=3
  535. )
  536. resblock = ResBlock1 if resblock == "1" else ResBlock2
  537. self.ups = nn.ModuleList()
  538. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  539. self.ups.append(
  540. weight_norm(
  541. ConvTranspose1d(
  542. upsample_initial_channel // (2**i),
  543. upsample_initial_channel // (2 ** (i + 1)),
  544. k,
  545. u,
  546. padding=(k - u) // 2,
  547. )
  548. )
  549. )
  550. self.resblocks = nn.ModuleList()
  551. for i in range(len(self.ups)):
  552. ch = upsample_initial_channel // (2 ** (i + 1))
  553. for j, (k, d) in enumerate(
  554. zip(resblock_kernel_sizes, resblock_dilation_sizes)
  555. ):
  556. self.resblocks.append(resblock(ch, k, d))
  557. self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
  558. self.ups.apply(init_weights)
  559. def forward(self, x):
  560. x = self.conv_pre(x)
  561. for i in range(self.num_upsamples):
  562. x = F.leaky_relu(x, LRELU_SLOPE)
  563. x = self.ups[i](x)
  564. xs = None
  565. for j in range(self.num_kernels):
  566. if xs is None:
  567. xs = self.resblocks[i * self.num_kernels + j](x)
  568. else:
  569. xs += self.resblocks[i * self.num_kernels + j](x)
  570. x = xs / self.num_kernels
  571. x = F.leaky_relu(x)
  572. x = self.conv_post(x)
  573. x = torch.tanh(x)
  574. return x
  575. def remove_weight_norm(self):
  576. print("Removing weight norm...")
  577. for l in self.ups:
  578. remove_weight_norm(l)
  579. for l in self.resblocks:
  580. l.remove_weight_norm()
  581. class DiscriminatorP(nn.Module):
  582. def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
  583. super(DiscriminatorP, self).__init__()
  584. self.period = period
  585. self.use_spectral_norm = use_spectral_norm
  586. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  587. self.convs = nn.ModuleList(
  588. [
  589. norm_f(
  590. Conv2d(
  591. 1,
  592. 32,
  593. (kernel_size, 1),
  594. (stride, 1),
  595. padding=(get_padding(kernel_size, 1), 0),
  596. )
  597. ),
  598. norm_f(
  599. Conv2d(
  600. 32,
  601. 128,
  602. (kernel_size, 1),
  603. (stride, 1),
  604. padding=(get_padding(kernel_size, 1), 0),
  605. )
  606. ),
  607. norm_f(
  608. Conv2d(
  609. 128,
  610. 512,
  611. (kernel_size, 1),
  612. (stride, 1),
  613. padding=(get_padding(kernel_size, 1), 0),
  614. )
  615. ),
  616. norm_f(
  617. Conv2d(
  618. 512,
  619. 1024,
  620. (kernel_size, 1),
  621. (stride, 1),
  622. padding=(get_padding(kernel_size, 1), 0),
  623. )
  624. ),
  625. norm_f(
  626. Conv2d(
  627. 1024,
  628. 1024,
  629. (kernel_size, 1),
  630. 1,
  631. padding=(get_padding(kernel_size, 1), 0),
  632. )
  633. ),
  634. ]
  635. )
  636. self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
  637. def forward(self, x):
  638. fmap = []
  639. # 1d to 2d
  640. b, c, t = x.shape
  641. if t % self.period != 0: # pad first
  642. n_pad = self.period - (t % self.period)
  643. x = F.pad(x, (0, n_pad), "reflect")
  644. t = t + n_pad
  645. x = x.view(b, c, t // self.period, self.period)
  646. for l in self.convs:
  647. x = l(x)
  648. x = F.leaky_relu(x, LRELU_SLOPE)
  649. fmap.append(x)
  650. x = self.conv_post(x)
  651. fmap.append(x)
  652. x = torch.flatten(x, 1, -1)
  653. return x, fmap
  654. class DiscriminatorS(nn.Module):
  655. def __init__(self, use_spectral_norm=False):
  656. super(DiscriminatorS, self).__init__()
  657. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  658. self.convs = nn.ModuleList(
  659. [
  660. norm_f(Conv1d(1, 16, 15, 1, padding=7)),
  661. norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
  662. norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
  663. norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
  664. norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
  665. norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
  666. ]
  667. )
  668. self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
  669. def forward(self, x):
  670. fmap = []
  671. for l in self.convs:
  672. x = l(x)
  673. x = F.leaky_relu(x, LRELU_SLOPE)
  674. fmap.append(x)
  675. x = self.conv_post(x)
  676. fmap.append(x)
  677. x = torch.flatten(x, 1, -1)
  678. return x, fmap
  679. class EnsembleDiscriminator(nn.Module):
  680. def __init__(self, use_spectral_norm=False):
  681. super(EnsembleDiscriminator, self).__init__()
  682. periods = [2, 3, 5, 7, 11]
  683. discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
  684. discs = discs + [
  685. DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
  686. ]
  687. self.discriminators = nn.ModuleList(discs)
  688. def forward(self, y, y_hat):
  689. y_d_rs = []
  690. y_d_gs = []
  691. fmap_rs = []
  692. fmap_gs = []
  693. for i, d in enumerate(self.discriminators):
  694. y_d_r, fmap_r = d(y)
  695. y_d_g, fmap_g = d(y_hat)
  696. y_d_rs.append(y_d_r)
  697. y_d_gs.append(y_d_g)
  698. fmap_rs.append(fmap_r)
  699. fmap_gs.append(fmap_g)
  700. return y_d_rs, y_d_gs, fmap_rs, fmap_gs