vq_manager.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from typing import Callable
  2. import torch
  3. from loguru import logger
  4. from fish_speech.models.dac.modded_dac import DAC
  5. class VQManager:
  6. def __init__(self):
  7. # Make Pylance happy (attribut/method not defined...)
  8. self.decoder_model: DAC
  9. self.load_audio: Callable
  10. def decode_vq_tokens(self, codes):
  11. logger.info(f"VQ features: {codes.shape}")
  12. if isinstance(self.decoder_model, DAC):
  13. return self.decoder_model.from_indices(codes[None])[0].squeeze()
  14. raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
  15. def encode_reference(self, reference_audio, enable_reference_audio):
  16. if enable_reference_audio and reference_audio is not None:
  17. # Load audios, and prepare basic info here
  18. if hasattr(self.decoder_model, "spec_transform"):
  19. sample_rate = self.decoder_model.spec_transform.sample_rate
  20. else:
  21. sample_rate = self.decoder_model.sample_rate
  22. reference_audio_content = self.load_audio(reference_audio, sample_rate)
  23. audios = torch.from_numpy(reference_audio_content).to(
  24. self.decoder_model.device
  25. )[None, None, :]
  26. audio_lengths = torch.tensor(
  27. [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
  28. )
  29. logger.info(
  30. f"Loaded audio with {audios.shape[2] / sample_rate:.2f} seconds"
  31. )
  32. # VQ Encoder
  33. if isinstance(self.decoder_model, DAC):
  34. prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
  35. logger.info(f"Encoded prompt: {prompt_tokens.shape}")
  36. else:
  37. raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
  38. else:
  39. prompt_tokens = None
  40. logger.info("No reference audio provided")
  41. return prompt_tokens