# reference_loader.py (4.2 KB)
  1. import io
  2. from hashlib import sha256
  3. from pathlib import Path
  4. from typing import Callable, Literal, Tuple
  5. import torch
  6. import torchaudio
  7. from loguru import logger
  8. from fish_speech.models.dac.modded_dac import DAC
  9. from fish_speech.utils.file import (
  10. AUDIO_EXTENSIONS,
  11. audio_to_bytes,
  12. list_files,
  13. read_ref_text,
  14. )
  15. from fish_speech.utils.schema import ServeReferenceAudio
  16. class ReferenceLoader:
  17. def __init__(self) -> None:
  18. """
  19. Component of the TTSInferenceEngine class.
  20. Loads and manages the cache for the reference audio and text.
  21. """
  22. self.ref_by_id: dict = {}
  23. self.ref_by_hash: dict = {}
  24. # Make Pylance happy (attribut/method not defined...)
  25. self.decoder_model: DAC
  26. self.encode_reference: Callable
  27. # Define the torchaudio backend
  28. backends = torchaudio.list_audio_backends()
  29. if "ffmpeg" in backends:
  30. self.backend = "ffmpeg"
  31. else:
  32. self.backend = "soundfile"
  33. def load_by_id(
  34. self,
  35. id: str,
  36. use_cache: Literal["on", "off"],
  37. ) -> Tuple:
  38. # Load the references audio and text by id
  39. ref_folder = Path("references") / id
  40. ref_folder.mkdir(parents=True, exist_ok=True)
  41. ref_audios = list_files(
  42. ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
  43. )
  44. if use_cache == "off" or id not in self.ref_by_id:
  45. # If the references are not already loaded, encode them
  46. prompt_tokens = [
  47. self.encode_reference(
  48. # decoder_model=self.decoder_model,
  49. reference_audio=audio_to_bytes(str(ref_audio)),
  50. enable_reference_audio=True,
  51. )
  52. for ref_audio in ref_audios
  53. ]
  54. prompt_texts = [
  55. read_ref_text(str(ref_audio.with_suffix(".lab")))
  56. for ref_audio in ref_audios
  57. ]
  58. self.ref_by_id[id] = (prompt_tokens, prompt_texts)
  59. else:
  60. # Reuse already encoded references
  61. logger.info("Use same references")
  62. prompt_tokens, prompt_texts = self.ref_by_id[id]
  63. return prompt_tokens, prompt_texts
  64. def load_by_hash(
  65. self,
  66. references: list[ServeReferenceAudio],
  67. use_cache: Literal["on", "off"],
  68. ) -> Tuple:
  69. # Load the references audio and text by hash
  70. audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
  71. cache_used = False
  72. prompt_tokens, prompt_texts = [], []
  73. for i, ref in enumerate(references):
  74. if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
  75. # If the references are not already loaded, encode them
  76. prompt_tokens.append(
  77. self.encode_reference(
  78. reference_audio=ref.audio,
  79. enable_reference_audio=True,
  80. )
  81. )
  82. prompt_texts.append(ref.text)
  83. self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], ref.text)
  84. else:
  85. # Reuse already encoded references
  86. cached_token, cached_text = self.ref_by_hash[audio_hashes[i]]
  87. prompt_tokens.append(cached_token)
  88. prompt_texts.append(cached_text)
  89. cache_used = True
  90. if cache_used:
  91. logger.info("Use same references")
  92. return prompt_tokens, prompt_texts
  93. def load_audio(self, reference_audio, sr):
  94. """
  95. Load the audio data from a file or bytes.
  96. """
  97. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  98. audio_data = reference_audio
  99. reference_audio = io.BytesIO(audio_data)
  100. waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
  101. if waveform.shape[0] > 1:
  102. waveform = torch.mean(waveform, dim=0, keepdim=True)
  103. if original_sr != sr:
  104. resampler = torchaudio.transforms.Resample(
  105. orig_freq=original_sr, new_freq=sr
  106. )
  107. waveform = resampler(waveform)
  108. audio = waveform.squeeze().numpy()
  109. return audio