extract_vq.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import os
  2. import subprocess as sp
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. from functools import lru_cache
  7. from pathlib import Path
  8. from random import Random
  9. import click
  10. import numpy as np
  11. import torch
  12. import torchaudio
  13. from einops import rearrange
  14. from hydra import compose, initialize
  15. from hydra.utils import instantiate
  16. from lightning import LightningModule
  17. from loguru import logger
  18. from omegaconf import OmegaConf
  19. from fish_speech.models.vqgan.utils import sequence_mask
  20. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  21. # register eval resolver
  22. OmegaConf.register_new_resolver("eval", eval)
  23. # This file is used to convert the audio files to text files using the Whisper model.
  24. # It's mainly used to generate the training data for the VQ model.
  25. RANK = int(os.environ.get("SLURM_PROCID", 0))
  26. WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
  27. logger_format = (
  28. "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
  29. "<level>{level: <8}</level> | "
  30. "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
  31. "{extra[rank]} - <level>{message}</level>"
  32. )
  33. logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
  34. logger.remove()
  35. logger.add(sys.stderr, format=logger_format)
  36. @lru_cache(maxsize=1)
  37. def get_model(
  38. config_name: str = "vqgan",
  39. checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
  40. ):
  41. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  42. cfg = compose(config_name=config_name)
  43. model: LightningModule = instantiate(cfg.model)
  44. state_dict = torch.load(
  45. checkpoint_path,
  46. map_location=model.device,
  47. )
  48. if "state_dict" in state_dict:
  49. state_dict = state_dict["state_dict"]
  50. model.load_state_dict(state_dict, strict=True)
  51. model.eval()
  52. model.cuda()
  53. logger.info(f"Loaded model")
  54. return model
  55. def process_batch(files: list[Path], model) -> float:
  56. wavs = []
  57. audio_lengths = []
  58. max_length = total_time = 0
  59. for file in files:
  60. wav, sr = torchaudio.load(file)
  61. if wav.shape[0] > 1:
  62. wav = wav.mean(dim=0, keepdim=True)
  63. wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
  64. wavs.append(wav)
  65. total_time += len(wav) / model.sampling_rate
  66. max_length = max(max_length, len(wav))
  67. audio_lengths.append(len(wav))
  68. # Pad to max length
  69. for i, wav in enumerate(wavs):
  70. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  71. audios = torch.stack(wavs, dim=0)[:, None]
  72. audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
  73. # Calculate lengths
  74. with torch.no_grad():
  75. # VQ Encoder
  76. features = gt_mels = model.mel_transform(
  77. audios, sample_rate=model.sampling_rate
  78. )
  79. if model.downsample is not None:
  80. features = model.downsample(features)
  81. feature_lengths = (
  82. audio_lengths
  83. / model.hop_length
  84. / (model.downsample.total_strides if model.downsample is not None else 1)
  85. ).long()
  86. feature_masks = torch.unsqueeze(
  87. sequence_mask(feature_lengths, features.shape[2]), 1
  88. ).to(gt_mels.dtype)
  89. text_features = model.mel_encoder(features, feature_masks)
  90. _, indices, _ = model.vq_encoder(text_features, feature_masks)
  91. if indices.ndim == 4:
  92. # Grouped vq
  93. assert indices.shape[-1] == 1, f"Residual vq is not supported"
  94. indices = indices.squeeze(-1)
  95. elif indices.ndim == 2:
  96. # Single vq
  97. indices = indices.unsqueeze(0)
  98. else:
  99. raise ValueError(f"Invalid indices shape {indices.shape}")
  100. indices = rearrange(indices, "c b t -> b c t")
  101. # Save to disk
  102. outputs = indices.cpu().numpy()
  103. for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
  104. feature = feature[:, :length]
  105. # (T,)
  106. with open(file.with_suffix(".npy"), "wb") as f:
  107. np.save(f, feature)
  108. return total_time
  109. @click.command()
  110. @click.argument("folder")
  111. @click.option("--num-workers", default=1)
  112. @click.option("--config-name", default="vqgan_pretrain")
  113. @click.option(
  114. "--checkpoint-path",
  115. default="checkpoints/vqgan-v1.pth",
  116. )
  117. @click.option("--batch-size", default=64)
  118. @click.option("--filelist", default=None, type=Path)
  119. def main(
  120. folder: str,
  121. num_workers: int,
  122. config_name: str,
  123. checkpoint_path: str,
  124. batch_size: int,
  125. filelist: Path,
  126. ):
  127. if num_workers > 1 and WORLD_SIZE != num_workers:
  128. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  129. logger.info(f"Spawning {num_workers} workers")
  130. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  131. if visible_devices is None:
  132. visible_devices = list(range(torch.cuda.device_count()))
  133. else:
  134. visible_devices = visible_devices.split(",")
  135. processes = []
  136. for i in range(num_workers):
  137. env = os.environ.copy()
  138. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  139. env["SLURM_PROCID"] = str(i)
  140. env["SLURM_NTASKS"] = str(num_workers)
  141. processes.append(
  142. sp.Popen(
  143. [sys.executable] + sys.argv.copy(),
  144. env=env,
  145. )
  146. )
  147. for p in processes:
  148. p.wait()
  149. logger.info(f"All workers finished")
  150. return
  151. # This is a worker
  152. logger.info(f"Starting worker")
  153. if filelist:
  154. files = [i[0] for i in load_filelist(filelist)]
  155. else:
  156. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
  157. Random(42).shuffle(files)
  158. total_files = len(files)
  159. files = files[RANK::WORLD_SIZE]
  160. logger.info(f"Processing {len(files)}/{total_files} files")
  161. # Batch processing
  162. total_time = 0
  163. begin_time = time.time()
  164. processed_files = 0
  165. model = get_model(config_name, checkpoint_path)
  166. for n_batch, idx in enumerate(range(0, len(files), batch_size)):
  167. batch = files[idx : idx + batch_size]
  168. batch_time = process_batch(batch, model)
  169. total_time += batch_time
  170. processed_files += len(batch)
  171. if (n_batch + 1) % 10 == 0:
  172. eta = (
  173. (time.time() - begin_time)
  174. / processed_files
  175. * (len(files) - processed_files)
  176. )
  177. logger.info(
  178. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  179. + f"ETA: {timedelta(seconds=round(eta))}s"
  180. )
  181. logger.info(
  182. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  183. )
  184. if __name__ == "__main__":
  185. main()