models.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. import copy
  2. import math
  3. import torch
  4. from torch import nn
  5. from torch.cuda.amp import autocast
  6. from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
  7. from torch.nn import functional as F
  8. from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
  9. from fish_speech.models.vqgan.modules import attentions, commons, modules
  10. from fish_speech.models.vqgan.modules.commons import get_padding, init_weights
  11. from fish_speech.models.vqgan.modules.rvq import DownsampleResidualVectorQuantizer
class FeatureEncoder(nn.Module):
    """Spectrogram encoder with a residual vector-quantized bottleneck.

    Pipeline: 1x1 projection -> WN encoder -> downsampled residual VQ ->
    WN decoder conditioned on a speaker embedding -> 1x1 projection to
    prior statistics (mean / log-std).  A small auxiliary WN decoder
    reconstructs a spectrogram from the decoded hidden states so an extra
    reconstruction loss can be applied during training.
    """

    def __init__(
        self,
        spec_channels,
        out_channels,
        hidden_channels,
        n_layers,
        kernel_size,
        p_dropout,
        codebook_size=1024,
        num_codebooks=2,
        gin_channels=0,
        aux_spec_channels=None,
    ):
        """
        Args:
            spec_channels: channel count of the input spectrogram.
            out_channels: channel count of the prior mean / log-std.
            hidden_channels: internal width of all WN stacks and the VQ.
            n_layers: total WN depth; split evenly between encoder and decoder.
            kernel_size: WN convolution kernel size.
            p_dropout: stored but not passed to any submodule below.
            codebook_size: entries per VQ codebook.
            num_codebooks: number of residual codebooks (also min_quantizers).
            gin_channels: speaker-embedding channels fed to the decoder WN.
            aux_spec_channels: output channels of the auxiliary reconstruction;
                falls back to ``spec_channels`` when ``None``.
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        # NOTE(review): p_dropout is unused below — kept for config compatibility.
        self.p_dropout = p_dropout
        if aux_spec_channels is None:
            aux_spec_channels = spec_channels
        self.spec_proj = nn.Conv1d(spec_channels, hidden_channels, 1)
        self.encoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=n_layers // 2,
        )
        self.vq = DownsampleResidualVectorQuantizer(
            input_dim=hidden_channels,
            n_codebooks=num_codebooks,
            codebook_size=codebook_size,
            codebook_dim=hidden_channels,
            min_quantizers=num_codebooks,
            # VQ runs at half the temporal rate of the encoder output.
            downsample_factor=(2,),
        )
        self.decoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=n_layers // 2,
            gin_channels=gin_channels,
        )
        self.aux_decoder = modules.WN(
            hidden_channels=hidden_channels,
            kernel_size=kernel_size,
            dilation_rate=1,
            n_layers=4,
        )
        self.aux_proj = nn.Conv1d(hidden_channels, aux_spec_channels, 1)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, ge):
        """Encode, quantize and decode a batch of spectrograms.

        Args:
            y: (B, spec_channels, T) spectrogram batch.
            y_lengths: (B,) valid frame counts, used to build the mask.
            ge: speaker embedding, passed as ``g`` to the decoder WN.

        Returns:
            Tuple of (decoded hidden states, prior mean ``m``, prior
            log-std ``logs``, frame mask, VQ loss, auxiliary spectrogram
            reconstruction).
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        y = self.spec_proj(y * y_mask) * y_mask
        y = self.encoder(y, y_mask) * y_mask
        # Assumes vq returns (quantized, codebook indices, vq/commit loss);
        # the indices are unused here — TODO confirm against the RVQ module.
        z, indices, loss_vq = self.vq(y)
        y = self.decoder(z, y_mask, g=ge) * y_mask
        decoded_aux_mel = self.aux_decoder(y, y_mask)
        decoded_aux_mel = self.aux_proj(decoded_aux_mel) * y_mask
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, loss_vq, decoded_aux_mel
  77. class ResidualCouplingBlock(nn.Module):
  78. def __init__(
  79. self,
  80. channels,
  81. hidden_channels,
  82. kernel_size,
  83. dilation_rate,
  84. n_layers,
  85. n_flows=4,
  86. gin_channels=0,
  87. ):
  88. super().__init__()
  89. self.channels = channels
  90. self.hidden_channels = hidden_channels
  91. self.kernel_size = kernel_size
  92. self.dilation_rate = dilation_rate
  93. self.n_layers = n_layers
  94. self.n_flows = n_flows
  95. self.gin_channels = gin_channels
  96. self.flows = nn.ModuleList()
  97. for i in range(n_flows):
  98. self.flows.append(
  99. modules.ResidualCouplingLayer(
  100. channels,
  101. hidden_channels,
  102. kernel_size,
  103. dilation_rate,
  104. n_layers,
  105. gin_channels=gin_channels,
  106. mean_only=True,
  107. )
  108. )
  109. self.flows.append(modules.Flip())
  110. def forward(self, x, x_mask, g=None, reverse=False):
  111. if not reverse:
  112. for flow in self.flows:
  113. x, _ = flow(x, x_mask, g=g, reverse=reverse)
  114. else:
  115. for flow in reversed(self.flows):
  116. x = flow(x, x_mask, g=g, reverse=reverse)
  117. return x
  118. class PosteriorEncoder(nn.Module):
  119. def __init__(
  120. self,
  121. in_channels,
  122. out_channels,
  123. hidden_channels,
  124. kernel_size,
  125. dilation_rate,
  126. n_layers,
  127. gin_channels=0,
  128. ):
  129. super().__init__()
  130. self.in_channels = in_channels
  131. self.out_channels = out_channels
  132. self.hidden_channels = hidden_channels
  133. self.kernel_size = kernel_size
  134. self.dilation_rate = dilation_rate
  135. self.n_layers = n_layers
  136. self.gin_channels = gin_channels
  137. self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
  138. self.enc = modules.WN(
  139. hidden_channels,
  140. kernel_size,
  141. dilation_rate,
  142. n_layers,
  143. gin_channels=gin_channels,
  144. )
  145. self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
  146. def forward(self, x, x_lengths, g=None):
  147. g = g.detach()
  148. x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
  149. x.dtype
  150. )
  151. x = self.pre(x) * x_mask
  152. x = self.enc(x, x_mask, g=g)
  153. stats = self.proj(x) * x_mask
  154. m, logs = torch.split(stats, self.out_channels, dim=1)
  155. z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
  156. return z, m, logs, x_mask
  157. class WNEncoder(nn.Module):
  158. def __init__(
  159. self,
  160. in_channels,
  161. out_channels,
  162. hidden_channels,
  163. kernel_size,
  164. dilation_rate,
  165. n_layers,
  166. gin_channels=0,
  167. ):
  168. super().__init__()
  169. self.in_channels = in_channels
  170. self.out_channels = out_channels
  171. self.hidden_channels = hidden_channels
  172. self.kernel_size = kernel_size
  173. self.dilation_rate = dilation_rate
  174. self.n_layers = n_layers
  175. self.gin_channels = gin_channels
  176. self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
  177. self.enc = modules.WN(
  178. hidden_channels,
  179. kernel_size,
  180. dilation_rate,
  181. n_layers,
  182. gin_channels=gin_channels,
  183. )
  184. self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
  185. self.norm = modules.LayerNorm(out_channels)
  186. def forward(self, x, x_lengths, g=None):
  187. x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
  188. x.dtype
  189. )
  190. x = self.pre(x) * x_mask
  191. x = self.enc(x, x_mask, g=g)
  192. out = self.proj(x) * x_mask
  193. out = self.norm(out)
  194. return out
  195. class Generator(torch.nn.Module):
  196. def __init__(
  197. self,
  198. initial_channel,
  199. resblock,
  200. resblock_kernel_sizes,
  201. resblock_dilation_sizes,
  202. upsample_rates,
  203. upsample_initial_channel,
  204. upsample_kernel_sizes,
  205. gin_channels=0,
  206. ):
  207. super(Generator, self).__init__()
  208. self.num_kernels = len(resblock_kernel_sizes)
  209. self.num_upsamples = len(upsample_rates)
  210. self.conv_pre = Conv1d(
  211. initial_channel, upsample_initial_channel, 7, 1, padding=3
  212. )
  213. resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
  214. self.ups = nn.ModuleList()
  215. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  216. self.ups.append(
  217. weight_norm(
  218. ConvTranspose1d(
  219. upsample_initial_channel // (2**i),
  220. upsample_initial_channel // (2 ** (i + 1)),
  221. k,
  222. u,
  223. padding=(k - u) // 2,
  224. )
  225. )
  226. )
  227. self.resblocks = nn.ModuleList()
  228. for i in range(len(self.ups)):
  229. ch = upsample_initial_channel // (2 ** (i + 1))
  230. for j, (k, d) in enumerate(
  231. zip(resblock_kernel_sizes, resblock_dilation_sizes)
  232. ):
  233. self.resblocks.append(resblock(ch, k, d))
  234. self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
  235. self.ups.apply(init_weights)
  236. if gin_channels != 0:
  237. self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
  238. def forward(self, x, g=None):
  239. x = self.conv_pre(x)
  240. if g is not None:
  241. x = x + self.cond(g)
  242. for i in range(self.num_upsamples):
  243. x = F.leaky_relu(x, modules.LRELU_SLOPE)
  244. x = self.ups[i](x)
  245. xs = None
  246. for j in range(self.num_kernels):
  247. if xs is None:
  248. xs = self.resblocks[i * self.num_kernels + j](x)
  249. else:
  250. xs += self.resblocks[i * self.num_kernels + j](x)
  251. x = xs / self.num_kernels
  252. x = F.leaky_relu(x)
  253. x = self.conv_post(x)
  254. x = torch.tanh(x)
  255. return x
  256. def remove_weight_norm(self):
  257. print("Removing weight norm...")
  258. for l in self.ups:
  259. remove_weight_norm(l)
  260. for l in self.resblocks:
  261. l.remove_weight_norm()
  262. class DiscriminatorP(torch.nn.Module):
  263. def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
  264. super(DiscriminatorP, self).__init__()
  265. self.period = period
  266. self.use_spectral_norm = use_spectral_norm
  267. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  268. self.convs = nn.ModuleList(
  269. [
  270. norm_f(
  271. Conv2d(
  272. 1,
  273. 32,
  274. (kernel_size, 1),
  275. (stride, 1),
  276. padding=(get_padding(kernel_size, 1), 0),
  277. )
  278. ),
  279. norm_f(
  280. Conv2d(
  281. 32,
  282. 128,
  283. (kernel_size, 1),
  284. (stride, 1),
  285. padding=(get_padding(kernel_size, 1), 0),
  286. )
  287. ),
  288. norm_f(
  289. Conv2d(
  290. 128,
  291. 512,
  292. (kernel_size, 1),
  293. (stride, 1),
  294. padding=(get_padding(kernel_size, 1), 0),
  295. )
  296. ),
  297. norm_f(
  298. Conv2d(
  299. 512,
  300. 1024,
  301. (kernel_size, 1),
  302. (stride, 1),
  303. padding=(get_padding(kernel_size, 1), 0),
  304. )
  305. ),
  306. norm_f(
  307. Conv2d(
  308. 1024,
  309. 1024,
  310. (kernel_size, 1),
  311. 1,
  312. padding=(get_padding(kernel_size, 1), 0),
  313. )
  314. ),
  315. ]
  316. )
  317. self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
  318. def forward(self, x):
  319. fmap = []
  320. # 1d to 2d
  321. b, c, t = x.shape
  322. if t % self.period != 0: # pad first
  323. n_pad = self.period - (t % self.period)
  324. x = F.pad(x, (0, n_pad), "reflect")
  325. t = t + n_pad
  326. x = x.view(b, c, t // self.period, self.period)
  327. for l in self.convs:
  328. x = l(x)
  329. x = F.leaky_relu(x, modules.LRELU_SLOPE)
  330. fmap.append(x)
  331. x = self.conv_post(x)
  332. fmap.append(x)
  333. x = torch.flatten(x, 1, -1)
  334. return x, fmap
  335. class DiscriminatorS(torch.nn.Module):
  336. def __init__(self, use_spectral_norm=False):
  337. super(DiscriminatorS, self).__init__()
  338. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  339. self.convs = nn.ModuleList(
  340. [
  341. norm_f(Conv1d(1, 16, 15, 1, padding=7)),
  342. norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
  343. norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
  344. norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
  345. norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
  346. norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
  347. ]
  348. )
  349. self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
  350. def forward(self, x):
  351. fmap = []
  352. for l in self.convs:
  353. x = l(x)
  354. x = F.leaky_relu(x, modules.LRELU_SLOPE)
  355. fmap.append(x)
  356. x = self.conv_post(x)
  357. fmap.append(x)
  358. x = torch.flatten(x, 1, -1)
  359. return x, fmap
  360. class EnsembledDiscriminator(torch.nn.Module):
  361. def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
  362. super(EnsembledDiscriminator, self).__init__()
  363. discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
  364. discs = discs + [
  365. DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
  366. ]
  367. self.discriminators = nn.ModuleList(discs)
  368. def forward(self, y, y_hat):
  369. y_d_rs = []
  370. y_d_gs = []
  371. fmap_rs = []
  372. fmap_gs = []
  373. for i, d in enumerate(self.discriminators):
  374. y_d_r, fmap_r = d(y)
  375. y_d_g, fmap_g = d(y_hat)
  376. y_d_rs.append(y_d_r)
  377. y_d_gs.append(y_d_g)
  378. fmap_rs.append(fmap_r)
  379. fmap_gs.append(fmap_g)
  380. return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    VITS-style model: a VQ prior encoder (``enc_p``), a posterior encoder
    (``enc_q``), a normalizing flow bridging the two latent spaces, a
    HiFi-GAN generator (``dec``) and a mel-style reference encoder
    (``ref_enc``) that produces the global speaker embedding ``ge``.
    """

    def __init__(
        self,
        *,
        spec_channels,
        segment_size,
        inter_channels,
        prior_hidden_channels,
        prior_n_layers,
        posterior_hidden_channels,
        posterior_n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        freeze_quantizer=False,
        codebook_size=1024,
        num_codebooks=2,
        freeze_decoder=False,
        freeze_posterior_encoder=False,
        aux_spec_channels=None,
    ):
        # All arguments are keyword-only (note the leading ``*``).
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.prior_hidden_channels = prior_hidden_channels
        self.prior_n_layers = prior_n_layers
        self.posterior_hidden_channels = posterior_hidden_channels
        self.posterior_n_layers = posterior_n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        # Prior encoder with the residual-VQ bottleneck.
        self.enc_p = FeatureEncoder(
            spec_channels=spec_channels,
            out_channels=inter_channels,
            hidden_channels=prior_hidden_channels,
            n_layers=prior_n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            codebook_size=codebook_size,
            num_codebooks=num_codebooks,
            gin_channels=gin_channels,
            aux_spec_channels=aux_spec_channels,
        )
        # HiFi-GAN waveform generator.
        self.dec = Generator(
            initial_channel=inter_channels,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_rates=upsample_rates,
            upsample_initial_channel=upsample_initial_channel,
            upsample_kernel_sizes=upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Posterior encoder q(z | spectrogram, speaker).
        self.enc_q = PosteriorEncoder(
            in_channels=spec_channels,
            out_channels=inter_channels,
            hidden_channels=posterior_hidden_channels,
            kernel_size=5,
            dilation_rate=1,
            n_layers=posterior_n_layers,
            gin_channels=gin_channels,
        )
        # Flow mapping posterior latents to the prior latent space.
        self.flow = ResidualCouplingBlock(
            inter_channels,
            posterior_hidden_channels,
            5,
            1,
            4,
            gin_channels=gin_channels,
        )
        # Reference encoder producing the global speaker embedding ``ge``.
        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )
        # Optional freezing for staged training / fine-tuning.
        # freeze_quantizer only freezes the pre-VQ half of enc_p
        # (projection, encoder, VQ); its decoder half keeps training.
        if freeze_quantizer:
            self.enc_p.spec_proj.requires_grad_(False)
            self.enc_p.encoder.requires_grad_(False)
            self.enc_p.vq.requires_grad_(False)
        if freeze_decoder:
            self.dec.requires_grad_(False)
        if freeze_posterior_encoder:
            self.enc_q.requires_grad_(False)

    def forward(self, y, y_lengths):
        """Training forward pass.

        Args:
            y: (B, spec_channels, T) spectrogram.
            y_lengths: (B,) valid frame counts.

        Returns:
            (waveform slice, slice ids, mask, mask, latent tuple
            ``(z, z_p, m_p, logs_p, m_q, logs_q)``, VQ loss, auxiliary
            reconstruction).  The mask appears twice — presumably to keep
            a legacy return signature; confirm against the training loop.
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)
        # NOTE(review): ``quantized`` is actually the VQ loss returned by
        # FeatureEncoder.forward, not a quantized tensor.
        x, m_p, logs_p, y_mask, quantized, decoded_aux_mel = self.enc_p(
            y, y_lengths, ge
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)
        # Decode only a random slice to bound generator memory use.
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)
        return (
            o,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
            decoded_aux_mel,
        )

    def infer(self, y, y_lengths, noise_scale=0.5):
        """Inference through the prior path: enc_p -> inverse flow -> dec.

        ``noise_scale`` scales the sampling noise around the prior mean.
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)
        x, m_p, logs_p, y_mask, _, _ = self.enc_p(y, y_lengths, ge)
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=ge, reverse=True)
        # The [:, :, :] slice is a no-op kept from the original code.
        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)

    def infer_posterior(self, y, y_lengths):
        """Reconstruct audio straight from the posterior latent (no flow),
        useful as an upper-bound quality check."""
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        ge = self.ref_enc(y * y_mask, y_mask)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        o = self.dec(z * y_mask, g=ge)
        return o, y_mask, (z, m_q, logs_q)

    # Dead code kept from an earlier GPT-SoVITS version; it references
    # attributes (self.quantizer, self.semantic_frame_rate, self.ssl_proj)
    # that this class no longer defines.
    # @torch.no_grad()
    # def decode(self, codes, text, refer, noise_scale=0.5):
    #     refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
    #     refer_mask = torch.unsqueeze(
    #         commons.sequence_mask(refer_lengths, refer.size(2)), 1
    #     ).to(refer.dtype)
    #     ge = self.ref_enc(refer * refer_mask, refer_mask)
    #     y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
    #     text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
    #     quantized = self.quantizer.decode(codes)
    #     if self.semantic_frame_rate == "25hz":
    #         quantized = F.interpolate(
    #             quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
    #         )
    #     x, m_p, logs_p, y_mask = self.enc_p(
    #         quantized, y_lengths, text, text_lengths, ge
    #     )
    #     z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
    #     z = self.flow(z_p, y_mask, g=ge, reverse=True)
    #     o = self.dec((z * y_mask)[:, :, :], g=ge)
    #     return o

    # def extract_latent(self, x):
    #     ssl = self.ssl_proj(x)
    #     quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
    #     return codes.transpose(0, 1)
if __name__ == "__main__":
    # One-off conversion script: instantiate the model, merge the separate
    # GPT-SoVITS generator / discriminator checkpoints into a single file
    # under "generator." / "discriminator." key prefixes, then exit.
    model = SynthesizerTrn(
        spec_channels=1025,
        segment_size=20480,
        inter_channels=192,
        prior_hidden_channels=384,
        posterior_hidden_channels=192,
        prior_n_layers=16,
        posterior_n_layers=16,
        kernel_size=3,
        p_dropout=0.1,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 8, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 8, 2, 2],
        gin_channels=512,
        freeze_quantizer=True,
    )
    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on trusted checkpoint files (consider weights_only=True).
    state_dict_g = torch.load("checkpoints/gpt_sovits_g_488k.pth", map_location="cpu")
    state_dict_d = torch.load("checkpoints/gpt_sovits_d_488k.pth", map_location="cpu")
    keys = set(model.state_dict().keys())
    # Drop prior-encoder weights (enc_p differs from the source model) and
    # any keys this model does not define.
    state_dict_g = {
        k: v for k, v in state_dict_g.items() if k in keys and "enc_p" not in k
    }
    new_state = {}
    for k, v in state_dict_g.items():
        new_state["generator." + k] = v
    for k, v in state_dict_d.items():
        new_state["discriminator." + k] = v
    torch.save(new_state, "checkpoints/gpt_sovits_488k.pth")
    exit()
    # ------------------------------------------------------------------
    # Everything below is unreachable (dead code after exit()); kept as a
    # manual smoke test of checkpoint loading and forward / inference.
    # ------------------------------------------------------------------
    # print(EnsembledDiscriminator().load_state_dict(state_dict_d, strict=False))
    print(model.load_state_dict(state_dict_g, strict=False))
    # y = torch.randn(3, 1025, 20480)
    # y_lengths = torch.tensor([20480, 19000, 18000])
    import librosa
    import soundfile as sf
    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram

    spec = LinearSpectrogram(
        n_fft=2048, win_length=2048, hop_length=640, mode="pow2_sqrt"
    )
    audio, _ = librosa.load(
        "/***REMOVED***/workspace/llm-multimodal-test/data/Rail_ZH/星/dbc16cc114ca1700.wav",
        sr=32000,
    )
    y = spec(torch.tensor(audio).unsqueeze(0))
    y_lengths = torch.tensor([y.size(2)])
    # NOTE(review): SynthesizerTrn.forward returns 7 values (the last is
    # decoded_aux_mel) but only 6 are unpacked here — this line would raise
    # ValueError if it were ever reached.
    o, ids_slice, y_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q), quantized = model(
        y, y_lengths
    )
    print(o.shape)
    o, y_mask, (z, z_p, m_p, logs_p) = model.infer(y, y_lengths)
    print(o.shape)
    o, y_mask, (z, m_q, logs_q) = model.infer_posterior(y, y_lengths)
    print(o.shape)
    o = o.squeeze(0).T.detach().cpu().numpy()
    sf.write("test.wav", o, 32000)