# models.py

import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from fish_speech.models.vits_decoder.modules import attentions, commons, modules

from .commons import get_padding, init_weights
from .mrte import MRTE
from .vq_encoder import VQEncoder
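
# Prior/text encoder: projects 768-dim SSL speech features into the model
# dimension, fuses them with text embeddings through MRTE, and predicts the
# prior distribution parameters (m, logs) used by the flow.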
class TextEncoder(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
        self.encoder_ssl = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )
        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.text_embedding = nn.Embedding(
            322, hidden_channels
        )  # only 264 entries are used; 322 matches the checkpoint's weight shape
        self.mrte = MRTE()
        self.encoder2 = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
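
    # forward: y is the SSL feature sequence (B, 768, T), text is token ids
    # (B, T_text), ge is the global style embedding from the reference encoder.
    # Setting test=1 zeroes the text, which isolates the speech branch.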
    def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.ssl_proj(y * y_mask) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(
            commons.sequence_mask(text_lengths, text.size(1)), 1
        ).to(y.dtype)
        if test == 1:
            text[:, :] = 0
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)

        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask

    def extract_latent(self, x):
        # NOTE: `self.quantizer` is not defined in this class; this helper will
        # raise AttributeError if called as written.
        x = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
        return codes.transpose(0, 1)

    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
        # NOTE: same caveat as above (`self.quantizer` and `self.vq_proj` are
        # not defined in this class).
        quantized = self.quantizer.decode(codes)
        y = self.vq_proj(quantized) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)
        y = self.mrte(y, y_mask, refer, refer_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized
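
# Normalizing flow: a stack of mean-only affine coupling layers with channel
# flips in between, as in VITS; reverse=True runs the inverse transform.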
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
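
# Posterior encoder: WaveNet-style (WN) layers over the linear spectrogram,
# producing the posterior (m, logs) and a reparameterized sample z.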
class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
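
# HiFi-GAN-style generator: transposed-conv upsampling interleaved with
# multi-kernel residual blocks, optionally conditioned on the style embedding g.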
class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
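
# Period discriminator: reshapes the waveform into 2D with the given period
# and applies strided 2D convolutions; one branch of the ensemble below.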
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d: pad so the length is divisible by the period, then fold
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
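
# Scale discriminator: strided, grouped 1D convolutions over the raw waveform.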
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
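
# Ensemble of one scale discriminator and five period discriminators
# (periods 2, 3, 5, 7, 11); returns logits and feature maps for both
# real (y) and generated (y_hat) audio.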
class EnsembledDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        *,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )
        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )

        # The VQ encoder is frozen and used only as a feature extractor
        self.vq = VQEncoder()
        for param in self.vq.parameters():
            param.requires_grad = False
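
    # Training forward pass: encode the reference style (ge), quantize the raw
    # audio with the frozen VQ encoder, run the prior and posterior encoders,
    # then decode a random slice of the posterior latent for the GAN losses.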
    def forward(self, audio, audio_lengths, y, y_lengths, text, text_lengths):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)

        quantized = self.vq(audio, audio_lengths, sr=32000)
        quantized = F.interpolate(quantized, size=int(y.shape[-1]), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)
        return (
            o,
            ids_slice,
            y_mask,  # prior and posterior share the same mask, so it appears twice
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )
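
    # Inference: draw z_p from the prior (scaled by noise_scale), invert the
    # flow to get z, then vocode with the generator.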
    def infer(
        self,
        audio,
        audio_lengths,
        y,
        y_lengths,
        text,
        text_lengths,
        test=None,
        noise_scale=0.5,
    ):
        # y_lengths = audio_lengths // 640  # 640 is the hop size
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)

        quantized = self.vq(audio, audio_lengths, sr=32000)
        print(quantized.size())
        quantized = F.interpolate(
            quantized, size=int(audio.shape[-1] // 640), mode="nearest"
        )
        print(quantized.size())

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, audio_lengths, text, text_lengths, ge, test=test
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=ge, reverse=True)
        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)
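
    # Decode directly from VQ codes plus text, using an optional reference
    # spectrogram for the style embedding.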
    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5):
        ge = None
        if refer is not None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(
                commons.sequence_mask(refer_lengths, refer.size(2)), 1
            ).to(refer.dtype)
            ge = self.ref_enc(refer * refer_mask, refer_mask)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        # NOTE: `self.quantizer` is not defined on this module (only `self.vq`
        # is); this method will raise AttributeError if called as written.
        quantized = self.quantizer.decode(codes)
        quantized = F.interpolate(
            quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
        )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o

    def extract_latent(self, x):
        # NOTE: `self.ssl_proj` and `self.quantizer` are likewise undefined on
        # this module; kept as-is from the original code.
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)
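
# Smoke test: build the model, load a checkpoint, and run inference on a pair
# of sample clips (a reference voice and an input utterance).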
if __name__ == "__main__":
    import librosa
    import soundfile as sf
    from transformers import AutoTokenizer

    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram

    model = SynthesizerTrn(
        spec_channels=1025,
        segment_size=20480 // 640,
        inter_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 8, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 8, 2, 2],
        gin_channels=512,
    )

    ckpt = "./checkpoints/s2_big2k1_158000.pth"

    # Try to load the model
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)
    print(model.load_state_dict(checkpoint, strict=False))

    # Test clips: a reference voice and an input utterance
    ref_audio = librosa.load(
        "data/source/Genshin/Chinese/五郎/vo_DQAQ010_15_gorou_07.wav", sr=32000
    )[0]
    input_audio = librosa.load(
        "data/source/Genshin/Chinese/空/vo_FDAQ003_46_hero_02.wav", sr=32000
    )[0]
    # ref_audio = input_audio

    # Chinese test sentence, passed verbatim to the tokenizer
    text = "(现在看来花瓶里的水并不是用来隐藏水迹,而是在莉莉安与考威尔的争斗中不小心撞破的…)"

    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)

    ref_audio = torch.tensor(ref_audio).unsqueeze(0).unsqueeze(0)
    ref_spec = spec(ref_audio)
    input_audio = torch.tensor(input_audio).unsqueeze(0).unsqueeze(0)
    text = tokenizer(text, return_tensors="pt")["input_ids"]
    print(ref_audio.size(), ref_spec.size(), input_audio.size(), text.size())

    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(
        input_audio,
        torch.LongTensor([input_audio.size(2)]),
        ref_spec,
        torch.LongTensor([ref_spec.size(2)]),
        text,
        torch.LongTensor([text.size(1)]),
    )
    print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())

    # Save output
    sf.write("output.wav", o.squeeze().detach().numpy(), 32000)