reference_loader.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. import io
  2. from hashlib import sha256
  3. from pathlib import Path
  4. from typing import Callable, Literal, Tuple
  5. import torch
  6. import torchaudio
  7. from loguru import logger
  8. from fish_speech.models.dac.modded_dac import DAC
  9. from fish_speech.utils.file import (
  10. AUDIO_EXTENSIONS,
  11. audio_to_bytes,
  12. list_files,
  13. read_ref_text,
  14. )
  15. from fish_speech.utils.schema import ServeReferenceAudio
  16. class ReferenceLoader:
  17. def __init__(self) -> None:
  18. """
  19. Component of the TTSInferenceEngine class.
  20. Loads and manages the cache for the reference audio and text.
  21. """
  22. self.ref_by_id: dict = {}
  23. self.ref_by_hash: dict = {}
  24. # Make Pylance happy (attribut/method not defined...)
  25. self.decoder_model: DAC
  26. self.encode_reference: Callable
  27. # Define the torchaudio backend
  28. # list_audio_backends() was removed in torchaudio 2.9
  29. try:
  30. backends = torchaudio.list_audio_backends()
  31. if "ffmpeg" in backends:
  32. self.backend = "ffmpeg"
  33. else:
  34. self.backend = "soundfile"
  35. except AttributeError:
  36. # torchaudio 2.9+ removed list_audio_backends()
  37. # Try ffmpeg first, fallback to soundfile
  38. try:
  39. import torchaudio.io._load_audio_fileobj # noqa: F401
  40. self.backend = "ffmpeg"
  41. except (ImportError, ModuleNotFoundError):
  42. self.backend = "soundfile"
  43. def load_by_id(
  44. self,
  45. id: str,
  46. use_cache: Literal["on", "off"],
  47. ) -> Tuple:
  48. # Load the references audio and text by id
  49. ref_folder = Path("references") / id
  50. ref_folder.mkdir(parents=True, exist_ok=True)
  51. ref_audios = list_files(
  52. ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
  53. )
  54. if use_cache == "off" or id not in self.ref_by_id:
  55. # If the references are not already loaded, encode them
  56. prompt_tokens = [
  57. self.encode_reference(
  58. # decoder_model=self.decoder_model,
  59. reference_audio=audio_to_bytes(str(ref_audio)),
  60. enable_reference_audio=True,
  61. )
  62. for ref_audio in ref_audios
  63. ]
  64. prompt_texts = [
  65. read_ref_text(str(ref_audio.with_suffix(".lab")))
  66. for ref_audio in ref_audios
  67. ]
  68. self.ref_by_id[id] = (prompt_tokens, prompt_texts)
  69. else:
  70. # Reuse already encoded references
  71. logger.info("Use same references")
  72. prompt_tokens, prompt_texts = self.ref_by_id[id]
  73. return prompt_tokens, prompt_texts
  74. def load_by_hash(
  75. self,
  76. references: list[ServeReferenceAudio],
  77. use_cache: Literal["on", "off"],
  78. ) -> Tuple:
  79. # Load the references audio and text by hash
  80. audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
  81. cache_used = False
  82. prompt_tokens, prompt_texts = [], []
  83. for i, ref in enumerate(references):
  84. if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
  85. # If the references are not already loaded, encode them
  86. prompt_tokens.append(
  87. self.encode_reference(
  88. reference_audio=ref.audio,
  89. enable_reference_audio=True,
  90. )
  91. )
  92. prompt_texts.append(ref.text)
  93. self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], ref.text)
  94. else:
  95. # Reuse already encoded references
  96. cached_token, cached_text = self.ref_by_hash[audio_hashes[i]]
  97. prompt_tokens.append(cached_token)
  98. prompt_texts.append(cached_text)
  99. cache_used = True
  100. if cache_used:
  101. logger.info("Use same references")
  102. return prompt_tokens, prompt_texts
  103. def load_audio(self, reference_audio: bytes | str, sr: int):
  104. """
  105. Load the audio data from a file or bytes.
  106. """
  107. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  108. audio_data = reference_audio
  109. reference_audio = io.BytesIO(audio_data)
  110. waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
  111. if waveform.shape[0] > 1:
  112. waveform = torch.mean(waveform, dim=0, keepdim=True)
  113. if original_sr != sr:
  114. resampler = torchaudio.transforms.Resample(
  115. orig_freq=original_sr, new_freq=sr
  116. )
  117. waveform = resampler(waveform)
  118. audio = waveform.squeeze().numpy()
  119. return audio
  120. def list_reference_ids(self) -> list[str]:
  121. """
  122. List all valid reference IDs (subdirectory names containing valid audio and .lab files).
  123. Returns:
  124. list[str]: List of valid reference IDs
  125. """
  126. ref_base_path = Path("references")
  127. if not ref_base_path.exists():
  128. return []
  129. valid_ids = []
  130. for ref_dir in ref_base_path.iterdir():
  131. if not ref_dir.is_dir():
  132. continue
  133. # Check if directory contains at least one audio file and corresponding .lab file
  134. audio_files = list_files(
  135. ref_dir, AUDIO_EXTENSIONS, recursive=False, sort=False
  136. )
  137. if not audio_files:
  138. continue
  139. # Check if corresponding .lab file exists for at least one audio file
  140. has_valid_pair = False
  141. for audio_file in audio_files:
  142. lab_file = audio_file.with_suffix(".lab")
  143. if lab_file.exists():
  144. has_valid_pair = True
  145. break
  146. if has_valid_pair:
  147. valid_ids.append(ref_dir.name)
  148. return sorted(valid_ids)
  149. def add_reference(self, id: str, wav_file_path: str, reference_text: str) -> None:
  150. """
  151. Add a new reference voice by creating a new directory and copying files.
  152. Args:
  153. id: Reference ID (directory name)
  154. wav_file_path: Path to the audio file to copy
  155. reference_text: Text content for the .lab file
  156. Raises:
  157. FileExistsError: If the reference ID already exists
  158. FileNotFoundError: If the audio file doesn't exist
  159. OSError: If file operations fail
  160. """
  161. # Validate ID format
  162. import re
  163. if not re.match(r"^[a-zA-Z0-9\-_ ]+$", id):
  164. raise ValueError(
  165. "Reference ID contains invalid characters. Only alphanumeric, hyphens, underscores, and spaces are allowed."
  166. )
  167. if len(id) > 255:
  168. raise ValueError(
  169. "Reference ID is too long. Maximum length is 255 characters."
  170. )
  171. # Check if reference already exists
  172. ref_dir = Path("references") / id
  173. if ref_dir.exists():
  174. raise FileExistsError(f"Reference ID '{id}' already exists")
  175. # Check if audio file exists
  176. audio_path = Path(wav_file_path)
  177. if not audio_path.exists():
  178. raise FileNotFoundError(f"Audio file not found: {wav_file_path}")
  179. # Validate audio file extension
  180. if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
  181. raise ValueError(
  182. f"Unsupported audio format: {audio_path.suffix}. Supported formats: {', '.join(AUDIO_EXTENSIONS)}"
  183. )
  184. try:
  185. # Create reference directory
  186. ref_dir.mkdir(parents=True, exist_ok=False)
  187. # Determine the target audio filename with original extension
  188. target_audio_path = ref_dir / f"sample{audio_path.suffix}"
  189. # Copy audio file
  190. import shutil
  191. shutil.copy2(audio_path, target_audio_path)
  192. # Create .lab file
  193. lab_path = ref_dir / "sample.lab"
  194. with open(lab_path, "w", encoding="utf-8") as f:
  195. f.write(reference_text)
  196. # Clear cache for this ID if it exists
  197. if id in self.ref_by_id:
  198. del self.ref_by_id[id]
  199. logger.info(f"Successfully added reference voice with ID: {id}")
  200. except Exception as e:
  201. # Clean up on failure
  202. if ref_dir.exists():
  203. import shutil
  204. shutil.rmtree(ref_dir)
  205. raise e
  206. def delete_reference(self, id: str) -> None:
  207. """
  208. Delete a reference voice by removing its directory and files.
  209. Args:
  210. id: Reference ID (directory name) to delete
  211. Raises:
  212. FileNotFoundError: If the reference ID doesn't exist
  213. OSError: If file operations fail
  214. """
  215. # Check if reference exists
  216. ref_dir = Path("references") / id
  217. if not ref_dir.exists():
  218. raise FileNotFoundError(f"Reference ID '{id}' does not exist")
  219. try:
  220. # Remove the entire reference directory
  221. import shutil
  222. shutil.rmtree(ref_dir)
  223. # Clear cache for this ID if it exists
  224. if id in self.ref_by_id:
  225. del self.ref_by_id[id]
  226. logger.info(f"Successfully deleted reference voice with ID: {id}")
  227. except Exception as e:
  228. logger.error(f"Failed to delete reference '{id}': {e}")
  229. raise OSError(f"Failed to delete reference '{id}': {e}")