- import torch
- import torch.nn.functional as F
- from einx import get_at
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
- from tools.vqgan.extract_vq import get_model
# Codebook padding token wrapped as a LongTensor (value comes from the
# conversation module). NOTE(review): not referenced anywhere in this
# snippet — confirm it is used elsewhere before removing.
PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])
class Encoder(torch.nn.Module):
    """Analysis half of the firefly GAN-VQ model: waveform -> quantizer indices.

    Thin wrapper so the audio-to-indices path can be traced/exported as a
    single module.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # ONNX export cannot represent complex-valued spectrogram outputs,
        # so force the real-valued representation.
        self.model.spec_transform.spectrogram.return_complex = False

    def forward(self, audios):
        """Encode a batch of waveforms into codebook indices."""
        spectrograms = self.model.spec_transform(audios)
        features = self.model.backbone(spectrograms)
        return self.model.quantizer.encode(features)
class Decoder(torch.nn.Module):
    """Synthesis half of the model: quantizer indices -> waveform.

    Re-implements the grouped residual-FSQ index lookup with plain torch
    ops (gather/permute instead of einx) so the graph can be exported to
    ONNX.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Force inference-mode flags on the head; checkpointing would
        # interfere with tracing/export.
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        """Fetch (and scale) codebook vectors for one RVQ group.

        cur_index: which group in ``residual_fsq.rvqs`` to read from.
        indices: (batch, num_quantizers, frames) index tensor; -1 marks
            levels dropped by quantize dropout — presumably; verify against
            the reshape in ``forward``.
        Returns: per-quantizer codes, shape (quantizer, batch, frames, dim).
        """
        batch_size, quantize_dim, q_dim = indices.shape
        # Code vector dimensionality of this group's codebooks.
        d_dim = self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.shape[2]

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct
        if (
            quantize_dim
            < self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
        ):
            assert (
                self.model.quantizer.residual_fsq.rvqs[cur_index].quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            # Pad the missing (finer) quantizer levels with the -1 sentinel.
            indices = F.pad(
                indices,
                (
                    0,
                    self.model.quantizer.residual_fsq.rvqs[cur_index].num_quantizers
                    - quantize_dim,
                ),
                value=-1,
            )

        # take care of quantizer dropout
        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        # Plain-torch equivalent of einx.get_at: expand indices to
        # (quantizer, batch, frame, dim) and gather along the codebook axis.
        all_codes = torch.gather(
            self.model.quantizer.residual_fsq.rvqs[cur_index].codebooks.unsqueeze(1),
            dim=2,
            index=indices.long()
            .permute(2, 0, 1)
            .unsqueeze(-1)
            .repeat(1, 1, 1, d_dim),  # q, batch_size, frame, dim
        )

        # Zero out the dummy codes fetched for dropped (-1) positions.
        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

        # scale the codes
        scales = (
            self.model.quantizer.residual_fsq.rvqs[cur_index]
            .scales.unsqueeze(1)
            .unsqueeze(1)
        )
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
        return all_codes

    def get_output_from_indices(self, cur_index, indices):
        # Sum the per-quantizer codes and project back to the group's
        # output dimension.
        codes = self.get_codes_from_indices(cur_index, indices)
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

    def forward(self, indices) -> torch.Tensor:
        """Decode grouped quantizer indices into audio.

        indices: (batch, groups * quantizers, length) — assumed from the
        view/permute below; TODO confirm against the encoder's output.
        """
        batch_size, _, length = indices.shape
        dims = self.model.quantizer.residual_fsq.dim
        groups = self.model.quantizer.residual_fsq.groups
        dim_per_group = dims // groups

        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        # NOTE(review): torch.empty defaults to float32 on CPU; assumes the
        # model runs in float32 on CPU (true for the export script in main).
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            # Each group fills its own slice of the feature dimension.
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )

        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        x = self.model.head(z)
        return x
def main():
    """Export the firefly GAN-VQ decoder to ONNX and sanity-check shapes.

    Loads the pretrained generator on CPU, wraps it in the Encoder/Decoder
    modules above, runs a dummy waveform through the round trip, and exports
    the decoder with dynamic batch/length axes. The encoder export is kept
    below, disabled, for reference.
    """
    gan_model = get_model(
        "firefly_gan_vq",
        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        device="cpu",
    )

    enc = Encoder(gan_model)
    dec = Decoder(gan_model)

    # Dummy mono audio: (batch, channels, samples).
    audio_example = torch.randn(1, 1, 96000)
    indices = enc(audio_example)
    print(dec(indices).shape)

    # Encoder export currently disabled — kept for reference.
    """
    torch.onnx.export(
        enc,
        audio_example,
        "encoder.onnx",
        dynamic_axes = {
            "audio": [0, 2],
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["audio"],
        output_names=["prompt"]
    )
    """

    torch.onnx.export(
        dec,
        indices,
        "decoder.onnx",
        dynamic_axes={
            # Batch and frame axes of the index tensor are dynamic.
            "prompt": [0, 2],
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["prompt"],
        output_names=["audio"],
    )

    # Sanity check: reuse the already-computed indices instead of
    # re-running the encoder twice (the model is deterministic in eval).
    print(indices.shape)
    print(dec(indices).shape)


# Guard the entry point so importing this module does not trigger the
# checkpoint load and ONNX export as a side effect.
if __name__ == "__main__":
    main()