Przeglądaj źródła

fix bugs in onnx tracer (#833)

* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

make return_complex optional

* add vqgan encoder export

add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test codes

add test codes

* Fix tracer bugs at padding

Fix tracer bugs at padding

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ναρουσέ·μ·γιουμεμί·Χινακάννα 1 rok temu
rodzic
commit
3d2f842e64
2 zmienionych plików z 24 dodań i 5 usunięć
  1. 9 1
      fish_speech/models/vqgan/modules/firefly.py
  2. 15 4
      tools/export-onnx.py

+ 9 - 1
fish_speech/models/vqgan/modules/firefly.py

@@ -43,7 +43,15 @@ def get_extra_padding_for_conv1d(
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
     n_frames = (length - kernel_size + padding_total) / stride + 1
-    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    # for tracer, math.ceil will make onnx graph become constant
+    if isinstance(n_frames, torch.Tensor):
+        ideal_length = (torch.ceil(n_frames).long() - 1) * stride + (
+            kernel_size - padding_total
+        )
+    else:
+        ideal_length = (math.ceil(n_frames) - 1) * stride + (
+            kernel_size - padding_total
+        )
     return ideal_length - length
 
 

+ 15 - 4
tools/export-onnx.py

@@ -1,3 +1,4 @@
+import onnxruntime
 import torch
 import torch.nn.functional as F
 
@@ -20,7 +21,7 @@ class Encoder(torch.nn.Module):
         z = self.model.quantizer.downsample(encoded_features)
         _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
         _, b, l, _ = indices.shape
-        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
+        return indices.permute(1, 0, 3, 2).long().view(b, -1, l)
 
 
 class Decoder(torch.nn.Module):
@@ -58,7 +59,7 @@ class Decoder(torch.nn.Module):
         all_codes = torch.gather(
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             dim=2,
-            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
+            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
         )
 
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
@@ -111,7 +112,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         audio_example,
         f"{export_prefix}encoder.onnx",
         dynamic_axes={
-            "audio": [0, 2],
+            "audio": {0: "batch_size", 2: "audio_length"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -125,7 +126,7 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         indices,
         f"{export_prefix}decoder.onnx",
         dynamic_axes={
-            "prompt": [0, 2],
+            "prompt": {0: "batch_size", 2: "frame_count"},
         },
         do_constant_folding=False,
         opset_version=18,
@@ -134,6 +135,16 @@ def main(firefly_gan_vq_path, llama_path, export_prefix):
         output_names=["audio"],
     )
 
+    test_example = torch.randn(1, 1, 96000 * 5)
+    encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
+    decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")
+
+    # check graph has no error
+    onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
+    torch_enc_out = enc(test_example)
+    onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
+    torch_dec_out = dec(torch_enc_out)
+
 
 if __name__ == "__main__":
     main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")