|
@@ -1,6 +1,5 @@
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
-from einx import get_at
|
|
|
|
|
|
|
|
|
|
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
|
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
|
|
from tools.vqgan.extract_vq import get_model
|
|
from tools.vqgan.extract_vq import get_model
|
|
@@ -17,8 +16,11 @@ class Encoder(torch.nn.Module):
|
|
|
def forward(self, audios):
|
|
def forward(self, audios):
|
|
|
mels = self.model.spec_transform(audios)
|
|
mels = self.model.spec_transform(audios)
|
|
|
encoded_features = self.model.backbone(mels)
|
|
encoded_features = self.model.backbone(mels)
|
|
|
- indices = self.model.quantizer.encode(encoded_features)
|
|
|
|
|
- return indices
|
|
|
|
|
|
|
+
|
|
|
|
|
+ z = self.model.quantizer.downsample(encoded_features)
|
|
|
|
|
+ _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
|
|
|
|
|
+ _, b, l, _ = indices.shape
|
|
|
|
|
+ return indices.permute(1, 0, 3, 2).contiguous().view(b, -1, l)
|
|
|
|
|
|
|
|
|
|
|
|
|
class Decoder(torch.nn.Module):
|
|
class Decoder(torch.nn.Module):
|
|
@@ -30,12 +32,9 @@ class Decoder(torch.nn.Module):
|
|
|
|
|
|
|
|
def get_codes_from_indices(self, cur_index, indices):
|
|
def get_codes_from_indices(self, cur_index, indices):
|
|
|
|
|
|
|
|
- batch_size, quantize_dim, q_dim = indices.shape
|
|
|
|
|
|
|
+ _, quantize_dim, _ = indices.shape
|
|
|
d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
|
|
d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]
|
|
|
|
|
|
|
|
- # because of quantize dropout, one can pass in indices that are coarse
|
|
|
|
|
- # and the network should be able to reconstruct
|
|
|
|
|
-
|
|
|
|
|
if (
|
|
if (
|
|
|
quantize_dim
|
|
quantize_dim
|
|
|
< self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
|
|
< self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
|
|
@@ -53,26 +52,17 @@ class Decoder(torch.nn.Module):
|
|
|
value=-1,
|
|
value=-1,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # take care of quantizer dropout
|
|
|
|
|
-
|
|
|
|
|
mask = indices == -1
|
|
mask = indices == -1
|
|
|
- indices = indices.masked_fill(
|
|
|
|
|
- mask, 0
|
|
|
|
|
- ) # have it fetch a dummy code to be masked out later
|
|
|
|
|
|
|
+ indices = indices.masked_fill(mask, 0)
|
|
|
|
|
|
|
|
all_codes = torch.gather(
|
|
all_codes = torch.gather(
|
|
|
self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
|
|
self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
|
|
|
dim=2,
|
|
dim=2,
|
|
|
- index=indices.long()
|
|
|
|
|
- .permute(2, 0, 1)
|
|
|
|
|
- .unsqueeze(-1)
|
|
|
|
|
- .repeat(1, 1, 1, d_dim), # q, batch_size, frame, dim
|
|
|
|
|
|
|
+ index=indices.long().permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
|
|
all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)
|
|
|
|
|
|
|
|
- # scale the codes
|
|
|
|
|
-
|
|
|
|
|
scales = (
|
|
scales = (
|
|
|
self.model.quantizer.residual_fsq.rvqs[cur_index]
|
|
self.model.quantizer.residual_fsq.rvqs[cur_index]
|
|
|
.scales.unsqueeze(1)
|
|
.scales.unsqueeze(1)
|
|
@@ -80,8 +70,6 @@ class Decoder(torch.nn.Module):
|
|
|
)
|
|
)
|
|
|
all_codes = all_codes * scales
|
|
all_codes = all_codes * scales
|
|
|
|
|
|
|
|
- # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
|
|
|
|
-
|
|
|
|
|
return all_codes
|
|
return all_codes
|
|
|
|
|
|
|
|
def get_output_from_indices(self, cur_index, indices):
|
|
def get_output_from_indices(self, cur_index, indices):
|
|
@@ -112,39 +100,30 @@ class Decoder(torch.nn.Module):
|
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
-def main():
|
|
|
|
|
- GanModel = get_model(
|
|
|
|
|
- "firefly_gan_vq",
|
|
|
|
|
- "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
|
|
|
|
- device="cpu",
|
|
|
|
|
- )
|
|
|
|
|
|
|
+def main(firefly_gan_vq_path, llama_path, export_prefix):
|
|
|
|
|
+ GanModel = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
|
|
|
enc = Encoder(GanModel)
|
|
enc = Encoder(GanModel)
|
|
|
dec = Decoder(GanModel)
|
|
dec = Decoder(GanModel)
|
|
|
audio_example = torch.randn(1, 1, 96000)
|
|
audio_example = torch.randn(1, 1, 96000)
|
|
|
indices = enc(audio_example)
|
|
indices = enc(audio_example)
|
|
|
-
|
|
|
|
|
- print(dec(indices).shape)
|
|
|
|
|
-
|
|
|
|
|
- """
|
|
|
|
|
torch.onnx.export(
|
|
torch.onnx.export(
|
|
|
enc,
|
|
enc,
|
|
|
audio_example,
|
|
audio_example,
|
|
|
- "encoder.onnx",
|
|
|
|
|
- dynamic_axes = {
|
|
|
|
|
|
|
+ f"{export_prefix}encoder.onnx",
|
|
|
|
|
+ dynamic_axes={
|
|
|
"audio": [0, 2],
|
|
"audio": [0, 2],
|
|
|
},
|
|
},
|
|
|
do_constant_folding=False,
|
|
do_constant_folding=False,
|
|
|
opset_version=18,
|
|
opset_version=18,
|
|
|
verbose=False,
|
|
verbose=False,
|
|
|
input_names=["audio"],
|
|
input_names=["audio"],
|
|
|
- output_names=["prompt"]
|
|
|
|
|
|
|
+ output_names=["prompt"],
|
|
|
)
|
|
)
|
|
|
- """
|
|
|
|
|
|
|
|
|
|
torch.onnx.export(
|
|
torch.onnx.export(
|
|
|
dec,
|
|
dec,
|
|
|
indices,
|
|
indices,
|
|
|
- "decoder.onnx",
|
|
|
|
|
|
|
+ f"{export_prefix}decoder.onnx",
|
|
|
dynamic_axes={
|
|
dynamic_axes={
|
|
|
"prompt": [0, 2],
|
|
"prompt": [0, 2],
|
|
|
},
|
|
},
|
|
@@ -155,8 +134,6 @@ def main():
|
|
|
output_names=["audio"],
|
|
output_names=["audio"],
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- print(enc(audio_example).shape)
|
|
|
|
|
- print(dec(enc(audio_example)).shape)
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
-main()
|
|
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ main("checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", None, "test_")
|