extract_vq.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import os
  2. import subprocess as sp
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. from functools import lru_cache
  7. from pathlib import Path
  8. from random import Random
  9. import click
  10. import numpy as np
  11. import torch
  12. import torchaudio
  13. from einops import rearrange
  14. from hydra import compose, initialize
  15. from hydra.utils import instantiate
  16. from lightning import LightningModule
  17. from loguru import logger
  18. from omegaconf import OmegaConf
  19. from fish_speech.models.vqgan.utils import sequence_mask
  20. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  21. # register eval resolver
  22. OmegaConf.register_new_resolver("eval", eval)
  23. # This file is used to convert the audio files to text files using the Whisper model.
  24. # It's mainly used to generate the training data for the VQ model.
  25. RANK = int(os.environ.get("SLURM_PROCID", 0))
  26. WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
  27. logger_format = (
  28. "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
  29. "<level>{level: <8}</level> | "
  30. "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
  31. "{extra[rank]} - <level>{message}</level>"
  32. )
  33. logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
  34. logger.remove()
  35. logger.add(sys.stderr, format=logger_format)
  36. @lru_cache(maxsize=1)
  37. def get_model():
  38. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  39. cfg = compose(config_name="vqgan")
  40. model: LightningModule = instantiate(cfg.model)
  41. state_dict = torch.load(
  42. "checkpoints/vqgan/step_000380000.ckpt",
  43. map_location=model.device,
  44. )["state_dict"]
  45. model.load_state_dict(state_dict, strict=True)
  46. model.eval()
  47. model.cuda()
  48. logger.info("Restored model from checkpoint")
  49. logger.info(f"Loaded model")
  50. return model
  51. def process_batch(files: list[Path]) -> float:
  52. model = get_model()
  53. wavs = []
  54. audio_lengths = []
  55. max_length = total_time = 0
  56. for file in files:
  57. wav, sr = torchaudio.load(file)
  58. if wav.shape[0] > 1:
  59. wav = wav.mean(dim=0, keepdim=True)
  60. wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
  61. wavs.append(wav)
  62. total_time += len(wav) / model.sampling_rate
  63. max_length = max(max_length, len(wav))
  64. audio_lengths.append(len(wav))
  65. # Pad to max length
  66. for i, wav in enumerate(wavs):
  67. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  68. audios = torch.stack(wavs, dim=0)[:, None]
  69. audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
  70. # Calculate lengths
  71. with torch.no_grad():
  72. # VQ Encoder
  73. features = gt_mels = model.mel_transform(
  74. audios, sample_rate=model.sampling_rate
  75. )
  76. if model.downsample is not None:
  77. features = model.downsample(features)
  78. feature_lengths = (
  79. audio_lengths
  80. / model.hop_length
  81. / (model.downsample.total_strides if model.downsample is not None else 1)
  82. ).long()
  83. feature_masks = torch.unsqueeze(
  84. sequence_mask(feature_lengths, features.shape[2]), 1
  85. ).to(gt_mels.dtype)
  86. # vq_features is 50 hz, need to convert to true mel size
  87. text_features = model.mel_encoder(features, feature_masks)
  88. _, indices, _ = model.vq_encoder(text_features, feature_masks)
  89. indices = indices.squeeze(-1)
  90. indices = rearrange(indices, "c b t -> b c t")
  91. # Save to disk
  92. outputs = indices.cpu().numpy()
  93. for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
  94. feature = feature[:, :length]
  95. # (T,)
  96. with open(file.with_suffix(".npy"), "wb") as f:
  97. np.save(f, feature)
  98. return total_time
  99. @click.command()
  100. @click.argument("folder")
  101. @click.option("--num-workers", default=1)
  102. def main(folder: str, num_workers: int):
  103. if num_workers > 1 and WORLD_SIZE != num_workers:
  104. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  105. logger.info(f"Spawning {num_workers} workers")
  106. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  107. if visible_devices is None:
  108. visible_devices = list(range(torch.cuda.device_count()))
  109. else:
  110. visible_devices = visible_devices.split(",")
  111. processes = []
  112. for i in range(num_workers):
  113. env = os.environ.copy()
  114. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  115. env["SLURM_PROCID"] = str(i)
  116. env["SLURM_NTASKS"] = str(num_workers)
  117. processes.append(
  118. sp.Popen(
  119. [sys.executable] + sys.argv.copy(),
  120. env=env,
  121. )
  122. )
  123. for p in processes:
  124. p.wait()
  125. logger.info(f"All workers finished")
  126. return
  127. # This is a worker
  128. logger.info(f"Starting worker")
  129. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
  130. Random(42).shuffle(files)
  131. total_files = len(files)
  132. files = files[RANK::WORLD_SIZE]
  133. logger.info(f"Processing {len(files)}/{total_files} files")
  134. # Batch size 64
  135. total_time = 0
  136. begin_time = time.time()
  137. processed_files = 0
  138. for n_batch, idx in enumerate(range(0, len(files), 32)):
  139. batch = files[idx : idx + 32]
  140. batch_time = process_batch(batch)
  141. total_time += batch_time
  142. processed_files += len(batch)
  143. if (n_batch + 1) % 10 == 0:
  144. eta = (
  145. (time.time() - begin_time)
  146. / processed_files
  147. * (len(files) - processed_files)
  148. )
  149. logger.info(
  150. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  151. + f"ETA: {timedelta(seconds=round(eta))}s"
  152. )
  153. logger.info(
  154. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  155. )
  156. if __name__ == "__main__":
  157. main()