Bläddra i källkod

add onnx export code for vqgan model (#830)

* add onnx export code for vqgan model

add onnx export code for vqgan model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Ναρουσέ·μ·γιουμεμί·Χινακάννα 1 år sedan
förälder
incheckning
18965dea47
1 ändrade filer med 162 tillägg och 0 borttagningar
  1. 162 0
      tools/export-onnx.py

+ 162 - 0
tools/export-onnx.py

@@ -0,0 +1,162 @@
+import torch
+import torch.nn.functional as F
+from einx import get_at
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from tools.vqgan.extract_vq import get_model
+
+PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
+
+
+class Encoder(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.model.spec_transform.spectrogram.return_complex = False
+
+    def forward(self, audios):
+        mels = self.model.spec_transform(audios)
+        encoded_features = self.model.backbone(mels)
+        indices = self.model.quantizer.encode(encoded_features)
+        return indices
+
+
+class Decoder(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.model.head.training = False
+        self.model.head.checkpointing = False
+
+    def get_codes_from_indices(self, cur_index, indices):
+
+        batch_size, quantize_dim, q_dim = indices.shape
+        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
+
+        # because of quantize dropout, one can pass in indices that are coarse
+        # and the network should be able to reconstruct
+
+        if (
+            quantize_dim
+            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
+        ):
+            assert (
+                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
+            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
+            indices = F.pad(
+                indices,
+                (
+                    0,
+                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
+                    - quantize_dim,
+                ),
+                value=-1,
+            )
+
+        # take care of quantizer dropout
+
+        mask = indices == -1
+        indices = indices.masked_fill(
+            mask, 0
+        )  # have it fetch a dummy code to be masked out later
+
+        all_codes = torch.gather(
+            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
+            dim=2,
+            index=indices.long()
+            .permute(2, 0, 1)
+            .unsqueeze(-1)
+            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
+        )
+
+        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
+
+        # scale the codes
+
+        scales = (
+            self.model.quantizer.residual_fsq.rvqs[cur_index]
+            .scales.unsqueeze(1)
+            .unsqueeze(1)
+        )
+        all_codes = all_codes * scales
+
+        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
+
+        return all_codes
+
+    def get_output_from_indices(self, cur_index, indices):
+        codes = self.get_codes_from_indices(cur_index, indices)
+        codes_summed = codes.sum(dim=0)
+        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
+            codes_summed
+        )
+
+    def forward(self, indices) -> torch.Tensor:
+        batch_size, _, length = indices.shape
+        dims = self.model.quantizer.residual_fsq.dim
+        groups = self.model.quantizer.residual_fsq.groups
+        dim_per_group = dims // groups
+
+        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
+        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)
+
+        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
+        z_q = torch.empty((batch_size, length, dims))
+        for i in range(groups):
+            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
+                self.get_output_from_indices(i, indices[i])
+            )
+
+        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
+        x = self.model.head(z)
+        return x
+
+
+def main():
+    GanModel = get_model(
+        "firefly_gan_vq",
+        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        device="cpu",
+    )
+    enc = Encoder(GanModel)
+    dec = Decoder(GanModel)
+    audio_example = torch.randn(1, 1, 96000)
+    indices = enc(audio_example)
+
+    print(dec(indices).shape)
+
+    """
+    torch.onnx.export(
+        enc,
+        audio_example,
+        "encoder.onnx",
+        dynamic_axes = {
+            "audio": [0, 2],
+        },
+        do_constant_folding=False,
+        opset_version=18,
+        verbose=False,
+        input_names=["audio"],
+        output_names=["prompt"]
+    )
+    """
+
+    torch.onnx.export(
+        dec,
+        indices,
+        "decoder.onnx",
+        dynamic_axes={
+            "prompt": [0, 2],
+        },
+        do_constant_folding=False,
+        opset_version=18,
+        verbose=False,
+        input_names=["prompt"],
+        output_names=["audio"],
+    )
+
+    print(enc(audio_example).shape)
+    print(dec(enc(audio_example)).shape)
+
+
+main()