# export-onnx.py
  1. import torch
  2. import torch.nn.functional as F
  3. from einx import get_at
  4. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  5. from tools.vqgan.extract_vq import get_model
  6. PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
  7. class Encoder(torch.nn.Module):
  8. def __init__(self, model):
  9. super().__init__()
  10. self.model = model
  11. self.model.spec_transform.spectrogram.return_complex = False
  12. def forward(self, audios):
  13. mels = self.model.spec_transform(audios)
  14. encoded_features = self.model.backbone(mels)
  15. indices = self.model.quantizer.encode(encoded_features)
  16. return indices
class Decoder(torch.nn.Module):
    """Export-friendly decoder: grouped codebook indices -> waveform.

    Re-implements the residual-FSQ index-to-code lookup with plain torch
    ops (F.pad / torch.gather / masked_fill) so the graph traces cleanly
    for ONNX export, instead of relying on the library helpers.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Force inference-mode flags on the head so export does not take
        # training-only paths (e.g. activation checkpointing).
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        """Fetch (and scale) codebook vectors for one residual-FSQ group.

        cur_index: which entry of residual_fsq.rvqs to decode with.
        indices:   3-D index tensor; unpacked below as
                   (batch, quantize, q).
        Returns a tensor laid out (quantize, batch, frame, dim) — see the
        permute/repeat comment on the gather below.
        """
        batch_size, quantize_dim, q_dim = indices.shape
        # Embedding dimension of this group's codebooks.
        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct
        if (
            quantize_dim
            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
        ):
            assert (
                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            # Pad the missing finer quantizer levels with the sentinel -1.
            indices = F.pad(
                indices,
                (
                    0,
                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
                    - quantize_dim,
                ),
                value=-1,
            )

        # take care of quantizer dropout
        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
            index=indices.long()
            .permute(2, 0, 1)
            .unsqueeze(-1)
            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
        )

        # Zero out the dummy codes that were fetched for dropped-out slots.
        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

        # scale the codes
        scales = (
            self.model.quantizer.residual_fsq.rvqs[cur_index]
            .scales.unsqueeze(1)
            .unsqueeze(1)
        )
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
        return all_codes

    def get_output_from_indices(self, cur_index, indices):
        """Sum codes over the quantizer axis and project back out."""
        codes = self.get_codes_from_indices(cur_index, indices)
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

    def forward(self, indices) -> torch.Tensor:
        """Decode indices shaped (batch, groups*r, length) into audio.

        The view/permute below requires the middle axis to factor as
        groups * r; each group's slice of the feature dim is filled
        independently, then upsampled and run through the head.
        """
        batch_size, _, length = indices.shape
        dims = self.model.quantizer.residual_fsq.dim
        groups = self.model.quantizer.residual_fsq.groups
        dim_per_group = dims // groups

        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        # NOTE(review): torch.empty defaults to CPU/float32; assumes export
        # runs on CPU and project_out yields float32 — confirm.
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )

        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        x = self.model.head(z)
        return x
  89. def main():
  90. GanModel = get_model(
  91. "firefly_gan_vq",
  92. "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  93. device="cpu",
  94. )
  95. enc = Encoder(GanModel)
  96. dec = Decoder(GanModel)
  97. audio_example = torch.randn(1, 1, 96000)
  98. indices = enc(audio_example)
  99. print(dec(indices).shape)
  100. """
  101. torch.onnx.export(
  102. enc,
  103. audio_example,
  104. "encoder.onnx",
  105. dynamic_axes = {
  106. "audio": [0, 2],
  107. },
  108. do_constant_folding=False,
  109. opset_version=18,
  110. verbose=False,
  111. input_names=["audio"],
  112. output_names=["prompt"]
  113. )
  114. """
  115. torch.onnx.export(
  116. dec,
  117. indices,
  118. "decoder.onnx",
  119. dynamic_axes={
  120. "prompt": [0, 2],
  121. },
  122. do_constant_folding=False,
  123. opset_version=18,
  124. verbose=False,
  125. input_names=["prompt"],
  126. output_names=["audio"],
  127. )
  128. print(enc(audio_example).shape)
  129. print(dec(enc(audio_example)).shape)
  130. main()