models.py
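"""VITS-style synthesizer for the fish_speech vits_decoder: a prior encoder
over SSL features and text, a posterior spectrogram encoder, a normalizing
flow bridging the two, a HiFi-GAN-style generator, and multi-period /
multi-scale discriminators for adversarial training."""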
import torch
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from fish_speech.models.vits_decoder.modules import attentions, commons, modules

from .commons import get_padding, init_weights
from .mrte import MRTE
from .vq_encoder import VQEncoder


class TextEncoder(nn.Module):
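    """Prior encoder over SSL features and text tokens.

    Projects 768-dim SSL features into the hidden space, encodes them and the
    embedded text tokens with transformer encoders, fuses both streams with
    the MRTE module conditioned on a global style vector, and predicts the
    prior mean / log-std via a 1x1 convolution.
    """
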
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
        codebook_size=264,
    ):
        super().__init__()

        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
        self.encoder_ssl = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.text_embedding = nn.Embedding(codebook_size, hidden_channels)

        self.mrte = MRTE()

        self.encoder2 = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, text, text_lengths, ge):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.ssl_proj(y * y_mask) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(
            commons.sequence_mask(text_lengths, text.size(1)), 1
        ).to(y.dtype)

        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)

        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask

    def extract_latent(self, x):
        # NOTE: `self.quantizer` is not defined on this class; this method
        # appears to predate the external VQEncoder and will raise
        # AttributeError if called as-is.
        x = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
        return codes.transpose(0, 1)

    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
        # NOTE: `self.quantizer` and `self.vq_proj` are likewise not defined
        # on this class; see extract_latent above.
        quantized = self.quantizer.decode(codes)

        y = self.vq_proj(quantized) * y_mask
        y = self.encoder_ssl(y * y_mask, y_mask)

        y = self.mrte(y, y_mask, refer, refer_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized


class ResidualCouplingBlock(nn.Module):
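    """Normalizing flow: a stack of affine coupling layers with Flip
    permutations between them. Runs forward during training and in reverse
    when sampling from the prior.
    """
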
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
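    """Posterior encoder q(z | spectrogram): a 1x1 pre-projection, a WaveNet
    (WN) stack conditioned on the global style vector, and a 1x1 projection
    to mean / log-std, from which z is sampled via the reparameterization
    trick.
    """
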
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()

        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
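    """HiFi-GAN-style waveform generator: transposed-convolution upsampling
    stages, each followed by a bank of multi-receptive-field residual blocks
    whose outputs are averaged, then a final conv + tanh to produce audio.
    """
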
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
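    """Multi-period discriminator: reshapes the 1-D waveform into a 2-D
    (frames x period) grid so strided 2-D convolutions can detect periodic
    artifacts; returns logits plus intermediate feature maps for
    feature-matching losses.
    """
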
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
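    """Scale discriminator: a stack of grouped, strided 1-D convolutions over
    the raw waveform; returns logits plus intermediate feature maps.
    """
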
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class EnsembledDiscriminator(torch.nn.Module):
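    """One scale discriminator plus one period discriminator per period,
    applied to both real (y) and generated (y_hat) waveforms; returns the
    per-discriminator logits and feature maps for the GAN and
    feature-matching losses (defined outside this file).
    """
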
    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
        super().__init__()

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for training: combines the prior text/SSL encoder, the
    posterior spectrogram encoder, the coupling-flow bridge, the HiFi-GAN
    generator, a mel style reference encoder, and a frozen VQ feature
    extractor.
    """
    def __init__(
        self,
        *,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        codebook_size=264,
    ):
        super().__init__()

        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            codebook_size=codebook_size,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )

        self.vq = VQEncoder()
        for param in self.vq.parameters():
            param.requires_grad = False
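    # Training data flow (see forward below): the ground-truth spectrogram
    # yields a global style vector ge via ref_enc; frozen VQ features from the
    # raw audio are upsampled to the spectrogram frame rate and encoded by
    # enc_p into the prior (m_p, logs_p); enc_q gives the posterior z, which
    # the flow maps to z_p, presumably for a KL term, while a random slice of
    # z is decoded to a waveform for losses defined outside this file.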
    def forward(
        self, audio, audio_lengths, gt_specs, gt_spec_lengths, text, text_lengths
    ):
        y_mask = torch.unsqueeze(
            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
        ).to(gt_specs.dtype)
        ge = self.ref_enc(gt_specs * y_mask, y_mask)

        quantized = self.vq(audio, audio_lengths)
        quantized = F.interpolate(quantized, size=gt_specs.size(-1), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, gt_spec_lengths, text, text_lengths, ge
        )
        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, gt_spec_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)

        return (
            o,
            ids_slice,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
        )
    @torch.no_grad()
    def infer(
        self,
        audio,
        audio_lengths,
        gt_specs,
        gt_spec_lengths,
        text,
        text_lengths,
        noise_scale=0.5,
    ):
        y_mask = torch.unsqueeze(
            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
        ).to(gt_specs.dtype)
        ge = self.ref_enc(gt_specs * y_mask, y_mask)

        quantized = self.vq(audio, audio_lengths)
        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, audio_lengths, text, text_lengths, ge
        )

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec(z * y_mask, g=ge)
        return o
    @torch.no_grad()
    def infer_posterior(
        self,
        gt_specs,
        gt_spec_lengths,
    ):
        y_mask = torch.unsqueeze(
            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
        ).to(gt_specs.dtype)
        ge = self.ref_enc(gt_specs * y_mask, y_mask)

        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
        o = self.dec(z * y_mask, g=ge)

        return o
    @torch.no_grad()
    def decode(self, codes, text, refer, noise_scale=0.5):
        # TODO: not tested yet
        # NOTE: `self.quantizer` is not defined on this class (only the frozen
        # `self.vq`), so this path will raise AttributeError as written.
        ge = None
        if refer is not None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(
                commons.sequence_mask(refer_lengths, refer.size(2)), 1
            ).to(refer.dtype)
            ge = self.ref_enc(refer * refer_mask, refer_mask)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        quantized = F.interpolate(
            quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
        )

        x, m_p, logs_p, y_mask = self.enc_p(
            quantized, y_lengths, text, text_lengths, ge
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)
        o = self.dec(z * y_mask, g=ge)
        return o


if __name__ == "__main__":
    import librosa
    from transformers import AutoTokenizer

    from fish_speech.utils.spectrogram import LinearSpectrogram

    model = SynthesizerTrn(
        spec_channels=1025,
        segment_size=20480 // 640,
        inter_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[8, 8, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 8, 2, 2],
        gin_channels=512,
    )

    # Convert a Bert-VITS2 generator/discriminator checkpoint pair into a
    # single ensemble checkpoint, then exit.
    ckpt = "checkpoints/Bert-VITS2/G_0.pth"
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
    d_checkpoint = torch.load(
        "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
    )["model"]
    print(checkpoint.keys())
    checkpoint.pop("dec.cond.weight")
    checkpoint.pop("enc_q.enc.cond_layer.weight_v")

    new_checkpoint = {}
    for k, v in checkpoint.items():
        new_checkpoint["generator." + k] = v
    for k, v in d_checkpoint.items():
        new_checkpoint["discriminator." + k] = v
    torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
    exit()

    # The smoke test below is unreachable as written (see exit() above); it is
    # kept for reference.
    print(model.load_state_dict(checkpoint, strict=False))

    # Test: use the same clip as both reference and input audio.
    input_audio = librosa.load(
        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
    )[0]
    ref_audio = input_audio

    text = "博兴只知道身边的小女人没睡着,他又凑到她耳边压低了声线。阮苏眉睁眼,不觉得你老公像英雄吗?阮苏还是没反应,这男人是不是有病?刚才那冰冷又强势的样子,和现在这幼稚无赖的样子,根本就判若二人。"

    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)

    ref_audio = torch.tensor(ref_audio).unsqueeze(0).unsqueeze(0)
    ref_spec = spec(ref_audio)
    input_audio = torch.tensor(input_audio).unsqueeze(0).unsqueeze(0)
    text = tokenizer(text, return_tensors="pt")["input_ids"]
    print(ref_audio.size(), ref_spec.size(), input_audio.size(), text.size())

    # infer() returns only the generated waveform.
    o = model.infer(
        input_audio,
        torch.LongTensor([input_audio.size(2)]),
        ref_spec,
        torch.LongTensor([ref_spec.size(2)]),
        text,
        torch.LongTensor([text.size(1)]),
    )
    print(o.size())

    # Save output
    import soundfile as sf

    sf.write("output.wav", o.squeeze().detach().numpy(), 32000)