"""Export the firefly GAN-VQ encoder/decoder to ONNX and sanity-check the exported graphs."""
import onnxruntime
import torch
import torch.nn.functional as F

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model

# Codebook padding token wrapped as a 1-element LongTensor.
# NOTE(review): not referenced anywhere else in this file — presumably kept
# for importers of this module; confirm before removing.
PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
  7. class Encoder(torch.nn.Module):
  8. def __init__(self, model):
  9. super().__init__()
  10. self.model = model
  11. self.model.spec_transform.spectrogram.return_complex = False
  12. def forward(self, audios):
  13. mels = self.model.spec_transform(audios)
  14. encoded_features = self.model.backbone(mels)
  15. z = self.model.quantizer.downsample(encoded_features)
  16. _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
  17. _, b, l, _ = indices.shape
  18. return indices.permute(1, 0, 3, 2).long().view(b, -1, l)
class Decoder(torch.nn.Module):
    """ONNX-export wrapper: flattened quantizer indices in, waveform out.

    Re-implements ``residual_fsq.get_output_from_indices`` inline (per RVQ
    group) so the index-to-code lookup traces into ONNX-friendly ops
    (pad / masked_fill / gather) instead of the library's dynamic path.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Disable training-only behavior on the vocoder head; checkpointing
        # (activation recomputation) does not trace for export.
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        """Look up codebook vectors for one RVQ group.

        cur_index selects the group (``rvqs[cur_index]``); ``indices`` is
        (batch, quantize_dim, frames). Returns the per-quantizer codes,
        scaled, with quantizer as the leading dimension.
        """
        _, quantize_dim, _ = indices.shape
        # Width of a single codebook vector (codebooks is (q, codes, d_dim)
        # judging by the gather below — TODO confirm).
        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
        if (
            quantize_dim
            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
        ):
            # Fewer quantizer levels than the model owns is only valid when
            # the model was trained with quantize dropout; pad the missing
            # levels with the sentinel -1 so they decode to zero vectors.
            assert (
                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(
                indices,
                (
                    0,
                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
                    - quantize_dim,
                ),
                value=-1,
            )
        # Remember which positions are padding, then make them a safe index (0)
        # so the gather below never reads out of range.
        mask = indices == -1
        indices = indices.masked_fill(mask, 0)
        # Gather each index's code vector from its quantizer's codebook.
        all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
        )
        # Zero out the codes that came from padded (-1) positions.
        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
        # Per-quantizer scales, broadcast over batch and frame dims.
        scales = (
            self.model.quantizer.residual_fsq.rvqs[cur_index]
            .scales.unsqueeze(1)
            .unsqueeze(1)
        )
        all_codes = all_codes * scales
        return all_codes

    def get_output_from_indices(self, cur_index, indices):
        """Sum the per-quantizer codes for one group and project back out."""
        codes = self.get_codes_from_indices(cur_index, indices)
        # Residual quantization: the reconstruction is the sum over quantizers.
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

    def forward(self, indices) -> torch.Tensor:
        """Decode (batch, groups*quantizers, frames) indices into audio."""
        batch_size, _, length = indices.shape
        dims = self.model.quantizer.residual_fsq.dim
        groups = self.model.quantizer.residual_fsq.groups
        dim_per_group = dims // groups
        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)
        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        # Decode each group independently into its slice of the latent.
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )
        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        x = self.model.head(z)
        return x
  81. def main(firefly_gan_vq_path, llama_path, export_prefix):
  82. GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
  83. enc = Encoder(GanModel)
  84. dec = Decoder(GanModel)
  85. audio_example = torch.randn(1, 1, 96000)
  86. indices = enc(audio_example)
  87. torch.onnx.export(
  88. enc,
  89. audio_example,
  90. f"{export_prefix}encoder.onnx",
  91. dynamic_axes={
  92. "audio": {0: "batch_size", 2: "audio_length"},
  93. },
  94. do_constant_folding=False,
  95. opset_version=18,
  96. verbose=False,
  97. input_names=["audio"],
  98. output_names=["prompt"],
  99. )
  100. torch.onnx.export(
  101. dec,
  102. indices,
  103. f"{export_prefix}decoder.onnx",
  104. dynamic_axes={
  105. "prompt": {0: "batch_size", 2: "frame_count"},
  106. },
  107. do_constant_folding=False,
  108. opset_version=18,
  109. verbose=False,
  110. input_names=["prompt"],
  111. output_names=["audio"],
  112. )
  113. test_example = torch.randn(1, 1, 96000 * 5)
  114. encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
  115. decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")
  116. # check graph has no error
  117. onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
  118. torch_enc_out = enc(test_example)
  119. onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
  120. torch_dec_out = dec(torch_enc_out)
if __name__ == "__main__":
    # Smoke-test the export with the default checkpoint; the second argument
    # (llama_path) is unused by main.
    main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")