reference_loader.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import io
  2. from hashlib import sha256
  3. from pathlib import Path
  4. from typing import Callable, Literal, Tuple
  5. import torch
  6. import torchaudio
  7. from loguru import logger
  8. from fish_speech.models.dac.modded_dac import DAC
  9. from fish_speech.utils.file import (
  10. AUDIO_EXTENSIONS,
  11. audio_to_bytes,
  12. list_files,
  13. read_ref_text,
  14. )
  15. from fish_speech.utils.schema import ServeReferenceAudio
  16. class ReferenceLoader:
  17. def __init__(self) -> None:
  18. """
  19. Component of the TTSInferenceEngine class.
  20. Loads and manages the cache for the reference audio and text.
  21. """
  22. self.ref_by_id: dict = {}
  23. self.ref_by_hash: dict = {}
  24. # Make Pylance happy (attribut/method not defined...)
  25. self.decoder_model: DAC
  26. self.encode_reference: Callable
  27. # Define the torchaudio backend
  28. backends = torchaudio.list_audio_backends()
  29. if "ffmpeg" in backends:
  30. self.backend = "ffmpeg"
  31. else:
  32. self.backend = "soundfile"
  33. def load_by_id(
  34. self,
  35. id: str,
  36. use_cache: Literal["on", "off"],
  37. ) -> Tuple:
  38. # Load the references audio and text by id
  39. ref_folder = Path("references") / id
  40. ref_folder.mkdir(parents=True, exist_ok=True)
  41. ref_audios = list_files(
  42. ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
  43. )
  44. if use_cache == "off" or id not in self.ref_by_id:
  45. # If the references are not already loaded, encode them
  46. prompt_tokens = [
  47. self.encode_reference(
  48. # decoder_model=self.decoder_model,
  49. reference_audio=audio_to_bytes(str(ref_audio)),
  50. enable_reference_audio=True,
  51. )
  52. for ref_audio in ref_audios
  53. ]
  54. prompt_texts = [
  55. read_ref_text(str(ref_audio.with_suffix(".lab")))
  56. for ref_audio in ref_audios
  57. ]
  58. self.ref_by_id[id] = (prompt_tokens, prompt_texts)
  59. else:
  60. # Reuse already encoded references
  61. logger.info("Use same references")
  62. prompt_tokens, prompt_texts = self.ref_by_id[id]
  63. return prompt_tokens, prompt_texts
  64. def load_by_hash(
  65. self,
  66. references: list[ServeReferenceAudio],
  67. use_cache: Literal["on", "off"],
  68. ) -> Tuple:
  69. # Load the references audio and text by hash
  70. audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]
  71. cache_used = False
  72. prompt_tokens, prompt_texts = [], []
  73. for i, ref in enumerate(references):
  74. if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
  75. # If the references are not already loaded, encode them
  76. prompt_tokens.append(
  77. self.encode_reference(
  78. reference_audio=ref.audio,
  79. enable_reference_audio=True,
  80. )
  81. )
  82. prompt_texts.append(ref.text)
  83. self.ref_by_hash[audio_hashes[i]] = (prompt_tokens[-1], ref.text)
  84. else:
  85. # Reuse already encoded references
  86. cached_token, cached_text = self.ref_by_hash[audio_hashes[i]]
  87. prompt_tokens.append(cached_token)
  88. prompt_texts.append(cached_text)
  89. cache_used = True
  90. if cache_used:
  91. logger.info("Use same references")
  92. return prompt_tokens, prompt_texts
  93. def load_audio(self, reference_audio: bytes | str, sr: int):
  94. """
  95. Load the audio data from a file or bytes.
  96. """
  97. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  98. audio_data = reference_audio
  99. reference_audio = io.BytesIO(audio_data)
  100. waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)
  101. if waveform.shape[0] > 1:
  102. waveform = torch.mean(waveform, dim=0, keepdim=True)
  103. if original_sr != sr:
  104. resampler = torchaudio.transforms.Resample(
  105. orig_freq=original_sr, new_freq=sr
  106. )
  107. waveform = resampler(waveform)
  108. audio = waveform.squeeze().numpy()
  109. return audio
  110. def list_reference_ids(self) -> list[str]:
  111. """
  112. List all valid reference IDs (subdirectory names containing valid audio and .lab files).
  113. Returns:
  114. list[str]: List of valid reference IDs
  115. """
  116. ref_base_path = Path("references")
  117. if not ref_base_path.exists():
  118. return []
  119. valid_ids = []
  120. for ref_dir in ref_base_path.iterdir():
  121. if not ref_dir.is_dir():
  122. continue
  123. # Check if directory contains at least one audio file and corresponding .lab file
  124. audio_files = list_files(
  125. ref_dir, AUDIO_EXTENSIONS, recursive=False, sort=False
  126. )
  127. if not audio_files:
  128. continue
  129. # Check if corresponding .lab file exists for at least one audio file
  130. has_valid_pair = False
  131. for audio_file in audio_files:
  132. lab_file = audio_file.with_suffix(".lab")
  133. if lab_file.exists():
  134. has_valid_pair = True
  135. break
  136. if has_valid_pair:
  137. valid_ids.append(ref_dir.name)
  138. return sorted(valid_ids)
  139. def add_reference(self, id: str, wav_file_path: str, reference_text: str) -> None:
  140. """
  141. Add a new reference voice by creating a new directory and copying files.
  142. Args:
  143. id: Reference ID (directory name)
  144. wav_file_path: Path to the audio file to copy
  145. reference_text: Text content for the .lab file
  146. Raises:
  147. FileExistsError: If the reference ID already exists
  148. FileNotFoundError: If the audio file doesn't exist
  149. OSError: If file operations fail
  150. """
  151. # Validate ID format
  152. import re
  153. if not re.match(r"^[a-zA-Z0-9\-_ ]+$", id):
  154. raise ValueError(
  155. "Reference ID contains invalid characters. Only alphanumeric, hyphens, underscores, and spaces are allowed."
  156. )
  157. if len(id) > 255:
  158. raise ValueError(
  159. "Reference ID is too long. Maximum length is 255 characters."
  160. )
  161. # Check if reference already exists
  162. ref_dir = Path("references") / id
  163. if ref_dir.exists():
  164. raise FileExistsError(f"Reference ID '{id}' already exists")
  165. # Check if audio file exists
  166. audio_path = Path(wav_file_path)
  167. if not audio_path.exists():
  168. raise FileNotFoundError(f"Audio file not found: {wav_file_path}")
  169. # Validate audio file extension
  170. if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
  171. raise ValueError(
  172. f"Unsupported audio format: {audio_path.suffix}. Supported formats: {', '.join(AUDIO_EXTENSIONS)}"
  173. )
  174. try:
  175. # Create reference directory
  176. ref_dir.mkdir(parents=True, exist_ok=False)
  177. # Determine the target audio filename with original extension
  178. target_audio_path = ref_dir / f"sample{audio_path.suffix}"
  179. # Copy audio file
  180. import shutil
  181. shutil.copy2(audio_path, target_audio_path)
  182. # Create .lab file
  183. lab_path = ref_dir / "sample.lab"
  184. with open(lab_path, "w", encoding="utf-8") as f:
  185. f.write(reference_text)
  186. # Clear cache for this ID if it exists
  187. if id in self.ref_by_id:
  188. del self.ref_by_id[id]
  189. logger.info(f"Successfully added reference voice with ID: {id}")
  190. except Exception as e:
  191. # Clean up on failure
  192. if ref_dir.exists():
  193. import shutil
  194. shutil.rmtree(ref_dir)
  195. raise e
  196. def delete_reference(self, id: str) -> None:
  197. """
  198. Delete a reference voice by removing its directory and files.
  199. Args:
  200. id: Reference ID (directory name) to delete
  201. Raises:
  202. FileNotFoundError: If the reference ID doesn't exist
  203. OSError: If file operations fail
  204. """
  205. # Check if reference exists
  206. ref_dir = Path("references") / id
  207. if not ref_dir.exists():
  208. raise FileNotFoundError(f"Reference ID '{id}' does not exist")
  209. try:
  210. # Remove the entire reference directory
  211. import shutil
  212. shutil.rmtree(ref_dir)
  213. # Clear cache for this ID if it exists
  214. if id in self.ref_by_id:
  215. del self.ref_by_id[id]
  216. logger.info(f"Successfully deleted reference voice with ID: {id}")
  217. except Exception as e:
  218. logger.error(f"Failed to delete reference '{id}': {e}")
  219. raise OSError(f"Failed to delete reference '{id}': {e}")