vq_manager.py

from typing import Callable

import torch
from loguru import logger

from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture


class VQManager:
    """VQ encode/decode mixin; the host class provides `decoder_model` and `load_audio`."""

    def __init__(self):
        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: FireflyArchitecture
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode VQ token indices (codebooks x frames) back into a waveform."""
        feature_lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        if isinstance(self.decoder_model, FireflyArchitecture):
            return self.decoder_model.decode(
                indices=codes[None],
                feature_lengths=feature_lengths,
            )[0].squeeze()

        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode reference audio into VQ prompt tokens; return None if disabled."""
        if enable_reference_audio and reference_audio is not None:
            # Load the audio and prepare basic info here
            reference_audio_content = self.load_audio(
                reference_audio, self.decoder_model.spec_transform.sample_rate
            )

            audios = torch.from_numpy(reference_audio_content).to(
                self.decoder_model.device
            )[None, None, :]
            audio_lengths = torch.tensor(
                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
            )
            logger.info(
                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
            )

            # VQ Encoder
            if isinstance(self.decoder_model, FireflyArchitecture):
                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
            else:
                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
        else:
            prompt_tokens = None
            logger.info("No reference audio provided")

        return prompt_tokens
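

# Usage sketch (an assumption for illustration, not part of the upstream file):
# VQManager is written as a mix-in, so a host class is expected to supply
# `decoder_model` (a loaded FireflyArchitecture) and `load_audio` (a callable
# that takes an audio path and a target sample rate and returns a NumPy
# waveform). `decoder` and `read_audio` in the commented example below are
# hypothetical placeholders, not fish-speech APIs.


class SketchEngine(VQManager):
    """Hypothetical host class shown only to illustrate the expected wiring."""

    def __init__(self, decoder_model: FireflyArchitecture, load_audio: Callable):
        self.decoder_model = decoder_model
        self.load_audio = load_audio


# engine = SketchEngine(decoder, read_audio)  # `decoder`, `read_audio`: hypothetical
# prompt = engine.encode_reference("reference.wav", enable_reference_audio=True)
# waveform = engine.decode_vq_tokens(prompt)  # 1-D audio tensor on the model device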