extract_vq.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. import os
  2. import subprocess as sp
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. from functools import lru_cache
  7. from pathlib import Path
  8. from random import Random
  9. import click
  10. import numpy as np
  11. import torch
  12. import torchaudio
  13. from einops import rearrange
  14. from hydra import compose, initialize
  15. from hydra.utils import instantiate
  16. from lightning import LightningModule
  17. from loguru import logger
  18. from omegaconf import OmegaConf
  19. from fish_speech.models.vqgan.utils import sequence_mask
  20. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  21. # register eval resolver
  22. OmegaConf.register_new_resolver("eval", eval)
  23. # This file is used to convert the audio files to text files using the Whisper model.
  24. # It's mainly used to generate the training data for the VQ model.
  25. RANK = int(os.environ.get("SLURM_PROCID", 0))
  26. WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
  27. logger_format = (
  28. "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
  29. "<level>{level: <8}</level> | "
  30. "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
  31. "{extra[rank]} - <level>{message}</level>"
  32. )
  33. logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
  34. logger.remove()
  35. logger.add(sys.stderr, format=logger_format)
  36. @lru_cache(maxsize=1)
  37. def get_model(
  38. config_name: str = "vqgan",
  39. checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
  40. ):
  41. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  42. cfg = compose(config_name=config_name)
  43. model: LightningModule = instantiate(cfg.model)
  44. state_dict = torch.load(
  45. checkpoint_path,
  46. map_location=model.device,
  47. )["state_dict"]
  48. model.load_state_dict(state_dict, strict=True)
  49. model.eval()
  50. model.cuda()
  51. logger.info(f"Loaded model")
  52. return model
  53. def process_batch(files: list[Path], model) -> float:
  54. wavs = []
  55. audio_lengths = []
  56. max_length = total_time = 0
  57. for file in files:
  58. wav, sr = torchaudio.load(file)
  59. if wav.shape[0] > 1:
  60. wav = wav.mean(dim=0, keepdim=True)
  61. wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
  62. wavs.append(wav)
  63. total_time += len(wav) / model.sampling_rate
  64. max_length = max(max_length, len(wav))
  65. audio_lengths.append(len(wav))
  66. # Pad to max length
  67. for i, wav in enumerate(wavs):
  68. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  69. audios = torch.stack(wavs, dim=0)[:, None]
  70. audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
  71. # Calculate lengths
  72. with torch.no_grad():
  73. # VQ Encoder
  74. features = gt_mels = model.mel_transform(
  75. audios, sample_rate=model.sampling_rate
  76. )
  77. if model.downsample is not None:
  78. features = model.downsample(features)
  79. feature_lengths = (
  80. audio_lengths
  81. / model.hop_length
  82. / (model.downsample.total_strides if model.downsample is not None else 1)
  83. ).long()
  84. feature_masks = torch.unsqueeze(
  85. sequence_mask(feature_lengths, features.shape[2]), 1
  86. ).to(gt_mels.dtype)
  87. text_features = model.mel_encoder(features, feature_masks)
  88. _, indices, _ = model.vq_encoder(text_features, feature_masks)
  89. if indices.ndim == 4:
  90. # Grouped vq
  91. assert indices.shape[-1] == 1, f"Residual vq is not supported"
  92. indices = indices.squeeze(-1)
  93. elif indices.ndim == 2:
  94. # Single vq
  95. indices = indices.unsqueeze(0)
  96. else:
  97. raise ValueError(f"Invalid indices shape {indices.shape}")
  98. indices = rearrange(indices, "c b t -> b c t")
  99. # Save to disk
  100. outputs = indices.cpu().numpy()
  101. for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
  102. feature = feature[:, :length]
  103. # (T,)
  104. with open(file.with_suffix(".npy"), "wb") as f:
  105. np.save(f, feature)
  106. return total_time
  107. @click.command()
  108. @click.argument("folder")
  109. @click.option("--num-workers", default=1)
  110. @click.option("--config-name", default="vqgan")
  111. @click.option(
  112. "--checkpoint-path",
  113. default="checkpoints/vqgan/step_000380000.ckpt",
  114. )
  115. @click.option("--batch-size", default=64)
  116. def main(
  117. folder: str,
  118. num_workers: int,
  119. config_name: str,
  120. checkpoint_path: str,
  121. batch_size: int,
  122. ):
  123. if num_workers > 1 and WORLD_SIZE != num_workers:
  124. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  125. logger.info(f"Spawning {num_workers} workers")
  126. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  127. if visible_devices is None:
  128. visible_devices = list(range(torch.cuda.device_count()))
  129. else:
  130. visible_devices = visible_devices.split(",")
  131. processes = []
  132. for i in range(num_workers):
  133. env = os.environ.copy()
  134. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  135. env["SLURM_PROCID"] = str(i)
  136. env["SLURM_NTASKS"] = str(num_workers)
  137. processes.append(
  138. sp.Popen(
  139. [sys.executable] + sys.argv.copy(),
  140. env=env,
  141. )
  142. )
  143. for p in processes:
  144. p.wait()
  145. logger.info(f"All workers finished")
  146. return
  147. # This is a worker
  148. logger.info(f"Starting worker")
  149. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
  150. Random(42).shuffle(files)
  151. total_files = len(files)
  152. files = files[RANK::WORLD_SIZE]
  153. logger.info(f"Processing {len(files)}/{total_files} files")
  154. # Batch processing
  155. total_time = 0
  156. begin_time = time.time()
  157. processed_files = 0
  158. model = get_model(config_name, checkpoint_path)
  159. for n_batch, idx in enumerate(range(0, len(files), batch_size)):
  160. batch = files[idx : idx + batch_size]
  161. batch_time = process_batch(batch, model)
  162. total_time += batch_time
  163. processed_files += len(batch)
  164. if (n_batch + 1) % 10 == 0:
  165. eta = (
  166. (time.time() - begin_time)
  167. / processed_files
  168. * (len(files) - processed_files)
  169. )
  170. logger.info(
  171. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  172. + f"ETA: {timedelta(seconds=round(eta))}s"
  173. )
  174. logger.info(
  175. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  176. )
  177. if __name__ == "__main__":
  178. main()