# extract_vq.py
import os
import subprocess as sp
import sys
import time
from datetime import timedelta
from functools import lru_cache
from pathlib import Path
from random import Random

import click
import numpy as np
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
# Register the "eval" resolver so configs can use ${eval:...} expressions.
# NOTE(review): eval() on config strings is only safe with trusted configs.
OmegaConf.register_new_resolver("eval", eval)

# This script converts audio files into VQ token files (.npy) using the codec
# model. It's mainly used to generate the training data for the VQ model.

# Determine audio backend - list_audio_backends() was removed in torchaudio 2.9
try:
    backends = torchaudio.list_audio_backends()
    if "ffmpeg" in backends:
        backend = "ffmpeg"
    else:
        backend = "soundfile"
except AttributeError:
    # torchaudio 2.9+ removed list_audio_backends()
    # Try ffmpeg first, fallback to soundfile
    try:
        # Probe a private torchaudio module to detect ffmpeg support.
        # NOTE(review): relies on a torchaudio internal — verify across versions.
        import torchaudio.io._load_audio_fileobj  # Check if ffmpeg backend is available

        backend = "ffmpeg"
    except (ImportError, ModuleNotFoundError):
        backend = "soundfile"

# Shard identity for data-parallel processing: real values come from SLURM,
# or are injected by the self-spawning launcher in main(); defaults to a
# single-process run.
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

# Loguru format including the worker's rank so interleaved logs are readable.
logger_format = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
    "{extra[rank]} - <level>{message}</level>"
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
logger.remove()
logger.add(sys.stderr, format=logger_format)
  48. @lru_cache(maxsize=1)
  49. def get_model(
  50. config_name: str = "modded_dac_vq",
  51. checkpoint_path: str = "checkpoints/openaudio-s1-mini/codec.pth",
  52. device: str | torch.device = "cuda",
  53. ):
  54. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  55. cfg = compose(config_name=config_name)
  56. model = instantiate(cfg)
  57. state_dict = torch.load(
  58. checkpoint_path,
  59. map_location=device,
  60. )
  61. if "state_dict" in state_dict:
  62. state_dict = state_dict["state_dict"]
  63. if any("generator" in k for k in state_dict):
  64. state_dict = {
  65. k.replace("generator.", ""): v
  66. for k, v in state_dict.items()
  67. if "generator." in k
  68. }
  69. model.load_state_dict(state_dict, strict=False)
  70. model.eval()
  71. model.to(device)
  72. logger.info(f"Loaded model")
  73. return model
  74. @torch.inference_mode()
  75. def process_batch(files: list[Path], model) -> float:
  76. wavs = []
  77. audio_lengths = []
  78. new_files = []
  79. max_length = total_time = 0
  80. for file in files:
  81. try:
  82. wav, sr = torchaudio.load(
  83. str(file), backend=backend
  84. ) # Need to install libsox-dev
  85. except Exception as e:
  86. logger.error(f"Error reading {file}: {e}")
  87. continue
  88. if wav.shape[0] > 1:
  89. wav = wav.mean(dim=0, keepdim=True)
  90. wav = torchaudio.functional.resample(wav.cuda(), sr, model.sample_rate)[0]
  91. total_time += len(wav) / model.sample_rate
  92. max_length = max(max_length, len(wav))
  93. wavs.append(wav)
  94. audio_lengths.append(len(wav))
  95. new_files.append(file)
  96. files = new_files
  97. # Pad to max length
  98. for i, wav in enumerate(wavs):
  99. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  100. audios = torch.stack(wavs, dim=0)[:, None]
  101. audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
  102. # Calculate lengths
  103. indices, feature_lengths = model.encode(audios, audio_lengths)
  104. # Save to disk
  105. outputs = indices.cpu().numpy()
  106. for file, length, feature, audio_length in zip(
  107. files, feature_lengths, outputs, audio_lengths
  108. ):
  109. feature = feature[:, :length]
  110. # (T,)
  111. with open(file.with_suffix(".npy"), "wb") as f:
  112. np.save(f, feature)
  113. return total_time
  114. @click.command()
  115. @click.argument("folder")
  116. @click.option("--num-workers", default=1)
  117. @click.option("--config-name", default="modded_dac_vq")
  118. @click.option(
  119. "--checkpoint-path",
  120. default="checkpoints/s2-pro/codec.pth",
  121. )
  122. @click.option("--batch-size", default=64)
  123. @click.option("--filelist", default=None, type=Path)
  124. def main(
  125. folder: str,
  126. num_workers: int,
  127. config_name: str,
  128. checkpoint_path: str,
  129. batch_size: int,
  130. filelist: Path,
  131. ):
  132. if num_workers > 1 and WORLD_SIZE != num_workers:
  133. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  134. logger.info(f"Spawning {num_workers} workers")
  135. if torch.cuda.is_available():
  136. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  137. if visible_devices is None:
  138. visible_devices = list(range(torch.cuda.device_count()))
  139. else:
  140. visible_devices = visible_devices.split(",")
  141. else:
  142. # Set to empty string to avoid using GPU
  143. visible_devices = [""]
  144. processes = []
  145. for i in range(num_workers):
  146. env = os.environ.copy()
  147. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  148. env["SLURM_PROCID"] = str(i)
  149. env["SLURM_NTASKS"] = str(num_workers)
  150. processes.append(
  151. sp.Popen(
  152. [sys.executable] + sys.argv.copy(),
  153. env=env,
  154. )
  155. )
  156. for p in processes:
  157. p.wait()
  158. logger.info(f"All workers finished")
  159. return
  160. # This is a worker
  161. logger.info(f"Starting worker")
  162. if filelist:
  163. files = [i[0] for i in load_filelist(filelist)]
  164. else:
  165. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
  166. print(f"Found {len(files)} files")
  167. files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
  168. total_files = len(files)
  169. files = files[RANK::WORLD_SIZE]
  170. logger.info(f"Processing {len(files)}/{total_files} files")
  171. # Batch processing
  172. total_time = 0
  173. begin_time = time.time()
  174. processed_files = 0
  175. model = get_model(config_name, checkpoint_path)
  176. for n_batch, idx in enumerate(range(0, len(files), batch_size)):
  177. batch = files[idx : idx + batch_size]
  178. batch_time = process_batch(batch, model)
  179. total_time += batch_time
  180. processed_files += len(batch)
  181. if (n_batch + 1) % 10 == 0:
  182. eta = (
  183. (time.time() - begin_time)
  184. / processed_files
  185. * (len(files) - processed_files)
  186. )
  187. logger.info(
  188. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  189. + f"ETA: {timedelta(seconds=round(eta))}s"
  190. )
  191. logger.info(
  192. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  193. )
# CLI entry point: click parses sys.argv and dispatches to main().
if __name__ == "__main__":
    main()