|
|
@@ -1,3 +1,4 @@
|
|
|
+import onnxruntime
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
@@ -20,7 +21,7 @@ class Encoder(torch.nn.Module):
|
|
|
z = self.model.quantizer.downsample(encoded_features)
|
|
|
_, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
|
|
|
_, b, l, _ = indices.shape
|
|
|
- return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
|
|
|
+ return indices.permute(1, 0, 3, 2).long().view(b, -1, l)
|
|
|
|
|
|
|
|
|
class Decoder(torch.nn.Module):
|
|
|
@@ -58,7 +59,7 @@ class Decoder(torch.nn.Module):
|
|
|
all_codes = torch.gather(
|
|
|
self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
|
|
|
dim=2,
|
|
|
- index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
|
|
|
+ index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
|
|
|
)
|
|
|
|
|
|
all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
|
|
|
@@ -111,7 +112,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
|
|
|
audio_example,
|
|
|
f"{export_prefix}encoder.onnx",
|
|
|
dynamic_axes={
|
|
|
- "audio": [0, 2],
|
|
|
+ "audio": {0: "batch_size", 2: "audio_length"},
|
|
|
},
|
|
|
do_constant_folding=False,
|
|
|
opset_version=18,
|
|
|
@@ -125,7 +126,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
|
|
|
indices,
|
|
|
f"{export_prefix}decoder.onnx",
|
|
|
dynamic_axes={
|
|
|
- "prompt": [0, 2],
|
|
|
+ "prompt": {0: "batch_size", 2: "frame_count"},
|
|
|
},
|
|
|
do_constant_folding=False,
|
|
|
opset_version=18,
|
|
|
@@ -134,6 +135,16 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
|
|
|
output_names=["audio"],
|
|
|
)
|
|
|
|
|
|
+ test_example = torch.randn(1, 1, 96000 * 5)
|
|
|
+ encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
|
|
|
+ decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")
|
|
|
+
|
|
|
+ # check graph has no error
|
|
|
+ onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
|
|
|
+ torch_enc_out = enc(test_example)
|
|
|
+ onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
|
|
|
+ torch_dec_out = dec(torch_enc_out)
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
|