# vq_manager.py
  1. from typing import Callable
  2. import torch
  3. from loguru import logger
  4. from fish_speech.models.dac.modded_dac import DAC
  5. class VQManager:
  6. def __init__(self):
  7. # Make Pylance happy (attribut/method not defined...)
  8. self.decoder_model: DAC
  9. self.load_audio: Callable
  10. def decode_vq_tokens(self, codes):
  11. feature_lengths = torch.tensor(
  12. [codes.shape[1]], device=self.decoder_model.device
  13. )
  14. logger.info(f"VQ features: {codes.shape}")
  15. if isinstance(self.decoder_model, DAC):
  16. return self.decoder_model.decode(
  17. indices=codes[None],
  18. feature_lengths=feature_lengths,
  19. )[0].squeeze()
  20. raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
  21. def encode_reference(self, reference_audio, enable_reference_audio):
  22. if enable_reference_audio and reference_audio is not None:
  23. # Load audios, and prepare basic info here
  24. if hasattr(self.decoder_model, "spec_transform"):
  25. sample_rate = self.decoder_model.spec_transform.sample_rate
  26. else:
  27. sample_rate = self.decoder_model.sample_rate
  28. reference_audio_content = self.load_audio(reference_audio, sample_rate)
  29. audios = torch.from_numpy(reference_audio_content).to(
  30. self.decoder_model.device
  31. )[None, None, :]
  32. audio_lengths = torch.tensor(
  33. [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
  34. )
  35. logger.info(
  36. f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
  37. )
  38. # VQ Encoder
  39. if isinstance(self.decoder_model, DAC):
  40. prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
  41. logger.info(f"Encoded prompt: {prompt_tokens.shape}")
  42. else:
  43. raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
  44. else:
  45. prompt_tokens = None
  46. logger.info("No reference audio provided")
  47. return prompt_tokens