# modules.py

import math

import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from fish_speech.models.hubert_vq.utils import (
    convert_pad_shape,
    get_padding,
    init_weights,
)

LRELU_SLOPE = 0.1


class VQEncoder(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
        )
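

# Hedged usage sketch (not part of the original file): VQEncoder only builds
# the 6-layer Transformer stack above and defines no forward; with the default
# batch_first=False, nn.TransformerEncoder expects (seq_len, batch, d_model).
#
#   enc = VQEncoder()
#   feats = torch.randn(100, 2, 256)   # (seq_len, batch, d_model=256)
#   hidden = enc.encoder(feats)        # shape preserved: (100, 2, 256)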


class RelativeAttention(nn.Module):
    def __init__(
        self,
        channels,
        n_heads,
        p_dropout=0.0,
        window_size=4,
        window_heads_share=True,
        proximal_init=True,
        proximal_bias=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = window_heads_share
        self.proximal_init = proximal_init
        self.proximal_bias = proximal_bias
        self.k_channels = channels // n_heads

        self.qkv = nn.Linear(channels, channels * 3)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if window_heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.qkv.weight)

        if proximal_init:
            with torch.no_grad():
                # Sync qk weights
                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
                    self.channels : self.channels * 2
                ]
                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
                    self.channels : self.channels * 2
                ]

    def forward(self, x, key_padding_mask=None):
        # x: (batch, seq_len, channels)
        batch_size, seq_len, _ = x.size()
        qkv = (
            self.qkv(x)
            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
            .permute(2, 0, 3, 1, 4)
        )
        query, key, value = torch.unbind(qkv, dim=0)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            key_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_k, seq_len
            )
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local

        if self.proximal_bias:
            scores = scores + self._attention_bias_proximal(seq_len).to(
                device=scores.device, dtype=scores.dtype
            )

        # key_padding_mask: (batch, seq_len); True marks padded positions.
        if key_padding_mask is not None:
            assert key_padding_mask.size() == (
                batch_size,
                seq_len,
            ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
            assert (
                key_padding_mask.dtype == torch.bool
            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"

            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.n_heads, -1, -1
            )
            scores = scores.masked_fill(key_padding_mask, float("-inf"))

        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, seq_len
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )

        return output.reshape(batch_size, seq_len, self.n_heads * self.k_channels)

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1

        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings

        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings
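
    # Worked example of the pad-then-slice above (added note): emb_rel_* has
    # 2 * window_size + 1 rows, one per clipped relative offset. With
    # window_size=4 and length=50, the table is zero-padded to 99 rows and
    # sliced whole (2*50 - 1 = 99 offsets); with length=3 it is sliced to
    # rows 2..6 (2*3 - 1 = 5 offsets) with no padding.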

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()

        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so that it adds up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final
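
    # Shape walk-through of the trick above (added note): padding one column
    # gives (b, h, l, 2l); flattening and padding l-1 more zeros gives
    # 2l*l + l - 1 = (l+1)(2l-1) elements, which reshape to (b, h, l+1, 2l-1);
    # dropping the last row and the first l-1 columns leaves the
    # absolute-position scores (b, h, l, l).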

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()

        # Pad along the column dimension.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])

        # Add zeros at the beginning that skew the elements after the reshape.
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.

        Args:
            length: an integer scalar.

        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
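

# Hedged usage sketch (an illustration, not from the original file): the
# attention is batch-first, and True entries in key_padding_mask mark padded
# key positions to ignore.
#
#   attn = RelativeAttention(channels=256, n_heads=4)
#   x = torch.randn(2, 50, 256)                   # (batch, seq_len, channels)
#   mask = torch.zeros(2, 50, dtype=torch.bool)   # no padding
#   out = attn(x, key_padding_mask=mask)          # (2, 50, 256)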


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        # Each residual branch pairs a dilated conv from convs1 with a plain
        # (dilation=1) conv from convs2.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()

        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)
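

# Hedged usage sketch (an illustration, not from the original file): both
# ResBlocks are HiFi-GAN-style multi-receptive-field residual stacks over
# (batch, channels, time) tensors; the padding keeps the time axis unchanged.
#
#   block = ResBlock1(channels=128, kernel_size=3, dilation=(1, 3, 5))
#   y = block(torch.randn(2, 128, 100))   # shape preserved: (2, 128, 100)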


class Generator(nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock_cls = ResBlock1 if resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock_cls(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
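

# Hedged usage sketch (hyperparameters are assumptions, not the project's
# config): the generator upsamples by prod(upsample_rates), here
# 8*8*2*2 = 256 output samples per input frame.
#
#   gen = Generator(
#       initial_channel=256,
#       resblock="1",
#       resblock_kernel_sizes=[3, 7, 11],
#       resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
#       upsample_rates=[8, 8, 2, 2],
#       upsample_initial_channel=512,
#       upsample_kernel_sizes=[16, 16, 4, 4],
#   )
#   wav = gen(torch.randn(1, 256, 50))   # -> (1, 1, 50 * 256)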


class DiscriminatorP(nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()

        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # 1 -> 32 -> 128 -> 512 -> 1024 -> 1024, striding on the time axis
        # only; the final conv keeps stride 1.
        channels = [1, 32, 128, 512, 1024, 1024]
        strides = [stride, stride, stride, stride, 1]
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        (s, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
                for c_in, c_out, s in zip(channels[:-1], channels[1:], strides)
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # Fold 1d to 2d: reflect-pad t to a multiple of the period, then
        # reshape to (batch, channels, t // period, period).
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(nn.Module):
    def __init__(self, use_spectral_norm=False):
        super().__init__()

        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap


class EnsembleDiscriminator(nn.Module):
    def __init__(self, use_spectral_norm=False):
        super().__init__()

        periods = [2, 3, 5, 7, 11]
        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
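

# Hedged usage sketch (an illustration, not from the original file): the
# ensemble pairs one scale discriminator with five period discriminators
# (periods 2, 3, 5, 7, 11) and scores real and generated audio side by side.
#
#   mpd = EnsembleDiscriminator()
#   y = torch.randn(2, 1, 8192)       # real waveforms
#   y_hat = torch.randn(2, 1, 8192)   # generated waveforms
#   y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)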