reference_loader.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import io
  2. from hashlib import sha256
  3. from pathlib import Path
  4. from typing import Callable, Literal, Tuple
  5. import torch
  6. import torchaudio
  7. from loguru import logger
  8. from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
  9. from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
  10. from tools.schema import ServeReferenceAudio
  11. class ReferenceLoader:
  12. def __init__(self) -> None:
  13. """
  14. Component of the TTSInferenceEngine class.
  15. Loads and manages the cache for the reference audio and text.
  16. """
  17. self.ref_by_id: dict = {}
  18. self.ref_by_hash: dict = {}
  19. # Make Pylance happy (attribut/method not defined...)
  20. self.decoder_model: FireflyArchitecture
  21. self.encode_reference: Callable
  22. # Define the torchaudio backend
  23. backends = torchaudio.list_audio_backends()
  24. if "ffmpeg" in backends:
  25. self.backend = "ffmpeg"
  26. else:
  27. self.backend = "soundfile"
  28. def load_by_id(
  29. self,
  30. id: str,
  31. use_cache: Literal["on", "off"],
  32. ) -> Tuple:
  33. # Load the references audio and text by id
  34. ref_folder = Path("references") / id
  35. ref_folder.mkdir(parents=True, exist_ok=True)
  36. ref_audios = list_files(
  37. ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
  38. )
  39. if use_cache == "off" or id not in self.ref_by_id:
  40. # If the references are not already loaded, encode them
  41. prompt_tokens = [
  42. self.encode_reference(
  43. # decoder_model=self.decoder_model,
  44. reference_audio=audio_to_bytes(str(ref_audio)),
  45. enable_reference_audio=True,
  46. )
  47. for ref_audio in ref_audios
  48. ]
  49. prompt_texts = [
  50. read_ref_text(str(ref_audio.with_suffix(".lab")))
  51. for ref_audio in ref_audios
  52. ]
  53. self.ref_by_id[id] = (prompt_tokens, prompt_texts)
  54. else:
  55. # Reuse already encoded references
  56. logger.info("Use same references")
  57. prompt_tokens, prompt_texts = self.ref_by_id[id]
  58. return prompt_tokens, prompt_texts
  59. def load_by_hash(
  60. self,
  61. references: list[ServeReferenceAudio],
  62. use_cache: Literal["on", "off"],
  63. ) -> Tuple:
  64. # Load the references audio and text by hash
  65. audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
  66. cache_used = False
  67. prompt_tokens, prompt_texts = [], []
  68. for i, ref in enumerate(references):
  69. if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
  70. # If the references are not already loaded, encode them
  71. prompt_tokens.append(
  72. self.encode_reference(
  73. reference_audio=ref.audio,
  74. enable_reference_audio=True,
  75. )
  76. )
  77. prompt_texts.append(ref.text)
  78. self.ref_by_hash[audio_hashes[i]] = (prompt_tokens, prompt_texts)
  79. else:
  80. # Reuse already encoded references
  81. prompt_tokens, prompt_texts = self.ref_by_hash[audio_hashes[i]]
  82. cache_used = True
  83. if cache_used:
  84. logger.info("Use same references")
  85. return prompt_tokens, prompt_texts
  86. def load_audio(self, reference_audio, sr):
  87. """
  88. Load the audio data from a file or bytes.
  89. """
  90. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  91. audio_data = reference_audio
  92. reference_audio = io.BytesIO(audio_data)
  93. waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
  94. if waveform.shape[0] > 1:
  95. waveform = torch.mean(waveform, dim=0, keepdim=True)
  96. if original_sr != sr:
  97. resampler = torchaudio.transforms.Resample(
  98. orig_freq=original_sr, new_freq=sr
  99. )
  100. waveform = resampler(waveform)
  101. audio = waveform.squeeze().numpy()
  102. return audio