reference_loader.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import io
  2. import re
  3. from hashlib import sha256
  4. from pathlib import Path
  5. from typing import Callable, Literal, Tuple
  6. import torch
  7. import torchaudio
  8. from loguru import logger
  9. from fish_speech.models.dac.modded_dac import DAC
  10. from fish_speech.utils.file import (
  11. AUDIO_EXTENSIONS,
  12. audio_to_bytes,
  13. list_files,
  14. read_ref_text,
  15. )
  16. from fish_speech.utils.schema import ServeReferenceAudio
  17. _ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_ ]+$")
  18. class ReferenceLoader:
  19. def __init__(self) -> None:
  20. """
  21. Component of the TTSInferenceEngine class.
  22. Loads and manages the cache for the reference audio and text.
  23. """
  24. self.ref_by_id: dict = {}
  25. self.ref_by_hash: dict = {}
  26. # Make Pylance happy (attribut/method not defined...)
  27. self.decoder_model: DAC
  28. self.encode_reference: Callable
  29. # Define the torchaudio backend
  30. # list_audio_backends() was removed in torchaudio 2.9
  31. try:
  32. backends = torchaudio.list_audio_backends()
  33. if "ffmpeg" in backends:
  34. self.backend = "ffmpeg"
  35. else:
  36. self.backend = "soundfile"
  37. except AttributeError:
  38. # torchaudio 2.9+ removed list_audio_backends()
  39. # Try ffmpeg first, fallback to soundfile
  40. try:
  41. __import__("torchaudio.io._load_audio_fileobj")
  42. self.backend = "ffmpeg"
  43. except (ImportError, ModuleNotFoundError):
  44. self.backend = "soundfile"
  45. @staticmethod
  46. def _validate_id(id: str) -> None:
  47. if not _ID_PATTERN.match(id) or len(id) > 255:
  48. raise ValueError(
  49. "Reference ID contains invalid characters or is too long. "
  50. "Only alphanumeric, hyphens, underscores, and spaces are allowed."
  51. )
  52. def load_by_id(
  53. self,
  54. id: str,
  55. use_cache: Literal["on", "off"],
  56. ) -> Tuple:
  57. self._validate_id(id)
  58. # Load the references audio and text by id
  59. ref_folder = Path("references") / id
  60. ref_folder.mkdir(parents=True, exist_ok=True)
  61. ref_audios = list_files(
  62. ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
  63. )
  64. if use_cache == "off" or id not in self.ref_by_id:
  65. # If the references are not already loaded, encode them
  66. prompt_tokens = [
  67. self.encode_reference(
  68. # decoder_model=self.decoder_model,
  69. reference_audio=audio_to_bytes(str(ref_audio)),
  70. enable_reference_audio=True,
  71. )
  72. for ref_audio in ref_audios
  73. ]
  74. prompt_texts = [
  75. read_ref_text(str(ref_audio.with_suffix(".lab")))
  76. for ref_audio in ref_audios
  77. ]
  78. self.ref_by_id[id] = (prompt_tokens, prompt_texts)
  79. else:
  80. # Reuse already encoded references
  81. logger.info("Use same references")
  82. prompt_tokens, prompt_texts = self.ref_by_id[id]
  83. return prompt_tokens, prompt_texts
  84. def load_by_hash(
  85. self,
  86. references: list[ServeReferenceAudio],
  87. use_cache: Literal["on", "off"],
  88. ) -> Tuple:
  89. # Load the references audio and text by hash
  90. audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
  91. cache_used = False
  92. prompt_tokens, prompt_texts = [], []
  93. for i, ref in enumerate(references):
  94. if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
  95. # If the references are not already loaded, encode them
  96. prompt_tokens.append(
  97. self.encode_reference(
  98. reference_audio=ref.audio,
  99. enable_reference_audio=True,
  100. )
  101. )
  102. prompt_texts.append(ref.text)
  103. self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], ref.text)
  104. else:
  105. # Reuse already encoded references
  106. cached_token, cached_text = self.ref_by_hash[audio_hashes[i]]
  107. prompt_tokens.append(cached_token)
  108. prompt_texts.append(cached_text)
  109. cache_used = True
  110. if cache_used:
  111. logger.info("Use same references")
  112. return prompt_tokens, prompt_texts
  113. def load_audio(self, reference_audio: bytes | str, sr: int):
  114. """
  115. Load the audio data from a file or bytes.
  116. """
  117. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  118. audio_data = reference_audio
  119. reference_audio = io.BytesIO(audio_data)
  120. waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
  121. if waveform.shape[0] > 1:
  122. waveform = torch.mean(waveform, dim=0, keepdim=True)
  123. if original_sr != sr:
  124. resampler = torchaudio.transforms.Resample(
  125. orig_freq=original_sr, new_freq=sr
  126. )
  127. waveform = resampler(waveform)
  128. audio = waveform.squeeze().numpy()
  129. return audio
  130. def list_reference_ids(self) -> list[str]:
  131. """
  132. List all valid reference IDs (subdirectory names containing valid audio and .lab files).
  133. Returns:
  134. list[str]: List of valid reference IDs
  135. """
  136. ref_base_path = Path("references")
  137. if not ref_base_path.exists():
  138. return []
  139. valid_ids = []
  140. for ref_dir in ref_base_path.iterdir():
  141. if not ref_dir.is_dir():
  142. continue
  143. # Check if directory contains at least one audio file and corresponding .lab file
  144. audio_files = list_files(
  145. ref_dir, AUDIO_EXTENSIONS, recursive=False, sort=False
  146. )
  147. if not audio_files:
  148. continue
  149. # Check if corresponding .lab file exists for at least one audio file
  150. has_valid_pair = False
  151. for audio_file in audio_files:
  152. lab_file = audio_file.with_suffix(".lab")
  153. if lab_file.exists():
  154. has_valid_pair = True
  155. break
  156. if has_valid_pair:
  157. valid_ids.append(ref_dir.name)
  158. return sorted(valid_ids)
  159. def add_reference(self, id: str, wav_file_path: str, reference_text: str) -> None:
  160. """
  161. Add a new reference voice by creating a new directory and copying files.
  162. Args:
  163. id: Reference ID (directory name)
  164. wav_file_path: Path to the audio file to copy
  165. reference_text: Text content for the .lab file
  166. Raises:
  167. FileExistsError: If the reference ID already exists
  168. FileNotFoundError: If the audio file doesn't exist
  169. OSError: If file operations fail
  170. """
  171. self._validate_id(id)
  172. # Check if reference already exists
  173. ref_dir = Path("references") / id
  174. if ref_dir.exists():
  175. raise FileExistsError(f"Reference ID '{id}' already exists")
  176. # Check if audio file exists
  177. audio_path = Path(wav_file_path)
  178. if not audio_path.exists():
  179. raise FileNotFoundError(f"Audio file not found: {wav_file_path}")
  180. # Validate audio file extension
  181. if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
  182. raise ValueError(
  183. f"Unsupported audio format: {audio_path.suffix}. Supported formats: {', '.join(AUDIO_EXTENSIONS)}"
  184. )
  185. try:
  186. # Create reference directory
  187. ref_dir.mkdir(parents=True, exist_ok=False)
  188. # Determine the target audio filename with original extension
  189. target_audio_path = ref_dir / f"sample{audio_path.suffix}"
  190. # Copy audio file
  191. import shutil
  192. shutil.copy2(audio_path, target_audio_path)
  193. # Create .lab file
  194. lab_path = ref_dir / "sample.lab"
  195. with open(lab_path, "w", encoding="utf-8") as f:
  196. f.write(reference_text)
  197. # Clear cache for this ID if it exists
  198. if id in self.ref_by_id:
  199. del self.ref_by_id[id]
  200. logger.info(f"Successfully added reference voice with ID: {id}")
  201. except Exception as e:
  202. # Clean up on failure
  203. if ref_dir.exists():
  204. import shutil
  205. shutil.rmtree(ref_dir)
  206. raise e
  207. def delete_reference(self, id: str) -> None:
  208. """
  209. Delete a reference voice by removing its directory and files.
  210. Args:
  211. id: Reference ID (directory name) to delete
  212. Raises:
  213. FileNotFoundError: If the reference ID doesn't exist
  214. OSError: If file operations fail
  215. """
  216. self._validate_id(id)
  217. ref_dir = Path("references") / id
  218. if not ref_dir.exists():
  219. raise FileNotFoundError(f"Reference ID '{id}' does not exist")
  220. try:
  221. # Remove the entire reference directory
  222. import shutil
  223. shutil.rmtree(ref_dir)
  224. # Clear cache for this ID if it exists
  225. if id in self.ref_by_id:
  226. del self.ref_by_id[id]
  227. logger.info(f"Successfully deleted reference voice with ID: {id}")
  228. except Exception as e:
  229. logger.error(f"Failed to delete reference '{id}': {e}")
  230. raise OSError(f"Failed to delete reference '{id}': {e}")