فهرست منبع

add onnx export code for vqgan encoder (#831)

* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make return_complex optional

make return_complex optional

* add vqgan encoder export

add vqgan encoder export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ναρουσέ μ γιουμεμί Χινακάννα 1 سال پیش
والد
کامیت
e908d40b90
2 فایل‌های تغییر یافته به همراه 20 افزوده شده و 41 حذف شده
  1. 4 2
      fish_speech/utils/spectrogram.py
  2. 16 39
      tools/export-onnx.py

+ 4 - 2
fish_speech/utils/spectrogram.py

@@ -20,6 +20,7 @@ class LinearSpectrogram(nn.Module):
         self.hop_length = hop_length
         self.hop_length = hop_length
         self.center = center
         self.center = center
         self.mode = mode
         self.mode = mode
+        self.return_complex = True
 
 
         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
 
 
@@ -46,10 +47,11 @@ class LinearSpectrogram(nn.Module):
             pad_mode="reflect",
             pad_mode="reflect",
             normalized=False,
             normalized=False,
             onesided=True,
             onesided=True,
-            return_complex=True,
+            return_complex=self.return_complex,
         )
         )
 
 
-        spec = torch.view_as_real(spec)
+        if self.return_complex:
+            spec = torch.view_as_real(spec)
 
 
         if self.mode == "pow2_sqrt":
         if self.mode == "pow2_sqrt":
             spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
             spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

+ 16 - 39
tools/export-onnx.py

@@ -1,6 +1,5 @@
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
-from einx import get_at
 
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
 from tools.vqgan.extract_vq import get_model
 from tools.vqgan.extract_vq import get_model
@@ -17,8 +16,11 @@ class Encoder(torch.nn.Module):
     def forward(self, audios):
     def forward(self, audios):
         mels = self.model.spec_transform(audios)
         mels = self.model.spec_transform(audios)
         encoded_features = self.model.backbone(mels)
         encoded_features = self.model.backbone(mels)
-        indices = self.model.quantizer.encode(encoded_features)
-        return indices
+
+        z = self.model.quantizer.downsample(encoded_features)
+        _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
+        _, b, l, _ = indices.shape
+        return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
 
 
 
 
 class Decoder(torch.nn.Module):
 class Decoder(torch.nn.Module):
@@ -30,12 +32,9 @@ class Decoder(torch.nn.Module):
 
 
     def get_codes_from_indices(self, cur_index, indices):
     def get_codes_from_indices(self, cur_index, indices):
 
 
-        batch_size, quantize_dim, q_dim = indices.shape
+        _, quantize_dim, _ = indices.shape
         d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
         d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
 
 
-        # because of quantize dropout, one can pass in indices that are coarse
-        # and the network should be able to reconstruct
-
         if (
         if (
             quantize_dim
             quantize_dim
             < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
             < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
@@ -53,26 +52,17 @@ class Decoder(torch.nn.Module):
                 value=-1,
                 value=-1,
             )
             )
 
 
-        # take care of quantizer dropout
-
         mask = indices == -1
         mask = indices == -1
-        indices = indices.masked_fill(
-            mask, 0
-        )  # have it fetch a dummy code to be masked out later
+        indices = indices.masked_fill(mask, 0)
 
 
         all_codes = torch.gather(
         all_codes = torch.gather(
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
             dim=2,
             dim=2,
-            index=indices.long()
-            .permute(2, 0, 1)
-            .unsqueeze(-1)
-            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
+            index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
         )
         )
 
 
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
         all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
 
 
-        # scale the codes
-
         scales = (
         scales = (
             self.model.quantizer.residual_fsq.rvqs[cur_index]
             self.model.quantizer.residual_fsq.rvqs[cur_index]
             .scales.unsqueeze(1)
             .scales.unsqueeze(1)
@@ -80,8 +70,6 @@ class Decoder(torch.nn.Module):
         )
         )
         all_codes = all_codes * scales
         all_codes = all_codes * scales
 
 
-        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
-
         return all_codes
         return all_codes
 
 
     def get_output_from_indices(self, cur_index, indices):
     def get_output_from_indices(self, cur_index, indices):
@@ -112,39 +100,30 @@ class Decoder(torch.nn.Module):
         return x
         return x
 
 
 
 
-def main():
-    GanModel = get_model(
-        "firefly_gan_vq",
-        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
-        device="cpu",
-    )
+def main(firefly_gan_vq_path, llama_path, export_prefix):
+    GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
     enc = Encoder(GanModel)
     enc = Encoder(GanModel)
     dec = Decoder(GanModel)
     dec = Decoder(GanModel)
     audio_example = torch.randn(1, 1, 96000)
     audio_example = torch.randn(1, 1, 96000)
     indices = enc(audio_example)
     indices = enc(audio_example)
-
-    print(dec(indices).shape)
-
-    """
     torch.onnx.export(
     torch.onnx.export(
         enc,
         enc,
         audio_example,
         audio_example,
-        "encoder.onnx",
-        dynamic_axes = {
+        f"{export_prefix}encoder.onnx",
+        dynamic_axes={
             "audio": [0, 2],
             "audio": [0, 2],
         },
         },
         do_constant_folding=False,
         do_constant_folding=False,
         opset_version=18,
         opset_version=18,
         verbose=False,
         verbose=False,
         input_names=["audio"],
         input_names=["audio"],
-        output_names=["prompt"]
+        output_names=["prompt"],
     )
     )
-    """
 
 
     torch.onnx.export(
     torch.onnx.export(
         dec,
         dec,
         indices,
         indices,
-        "decoder.onnx",
+        f"{export_prefix}decoder.onnx",
         dynamic_axes={
         dynamic_axes={
             "prompt": [0, 2],
             "prompt": [0, 2],
         },
         },
@@ -155,8 +134,6 @@ def main():
         output_names=["audio"],
         output_names=["audio"],
     )
     )
 
 
-    print(enc(audio_example).shape)
-    print(dec(enc(audio_example)).shape)
-
 
 
-main()
+if __name__ == "__main__":
+    main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")