# calculate_hubert_features.py
# This file extracts HuBERT features from audio files and quantizes them
# against a k-means codebook, writing the centroid indices next to each file.
# It's mainly used to generate the training data for the VQ model.
  3. import os
  4. import subprocess as sp
  5. import sys
  6. import time
  7. from datetime import timedelta
  8. from functools import lru_cache
  9. from pathlib import Path
  10. from random import Random
  11. import click
  12. import numpy as np
  13. import torch
  14. import torchaudio
  15. from loguru import logger
  16. from transformers import HubertModel
  17. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
# Distributed identity: under SLURM (or the self-spawn launcher in main(),
# which fakes these variables for its children) each process reads its rank
# and the worker count from the environment; defaults are a single process.
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

# loguru format string; {extra[rank]} injects the per-process rank tag below.
logger_format = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
    "{extra[rank]} - <level>{message}</level>"
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
# Drop the default sink and re-attach stderr with the rank-aware format.
logger.remove()
logger.add(sys.stderr, format=logger_format)
  29. @lru_cache(maxsize=1)
  30. def get_hubert_model():
  31. model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
  32. model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
  33. model = model.half()
  34. model.eval()
  35. logger.info(f"Loaded model")
  36. return model
  37. def process_batch(files: list[Path], kmeans_centers: torch.Tensor) -> float:
  38. model = get_hubert_model()
  39. wavs = []
  40. max_length = total_time = 0
  41. for file in files:
  42. wav, sr = torchaudio.load(file)
  43. if wav.shape[0] > 1:
  44. wav = wav.mean(dim=0, keepdim=True)
  45. wav = torchaudio.functional.resample(wav.cuda(), sr, 16000)[0]
  46. if len(wav) > sr * 60:
  47. wav = wav[: sr * 60]
  48. wavs.append(wav)
  49. total_time += len(wav) / sr
  50. max_length = max(max_length, len(wav))
  51. # Pad to max length
  52. attention_mask = torch.ones(len(wavs), max_length, dtype=torch.float)
  53. feature_lengths = []
  54. if max_length % 320 != 0:
  55. max_length += 320 - max_length % 320
  56. for i, wav in enumerate(wavs):
  57. attention_mask[i, len(wav) :] = 0
  58. feature_lengths.append(int(len(wav) / 320))
  59. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  60. wavs = torch.stack(wavs, dim=0).half()
  61. attention_mask = attention_mask.cuda()
  62. # Calculate lengths
  63. with torch.no_grad():
  64. outputs = model(wavs, attention_mask=attention_mask).last_hidden_state
  65. # Find closest centroids
  66. kmeans_centers = kmeans_centers.to(dtype=outputs.dtype, device=outputs.device)
  67. distances = torch.cdist(outputs, kmeans_centers)
  68. outputs = torch.min(distances, dim=-1)
  69. avg_distance = torch.mean(outputs.values)
  70. # Save to disk
  71. outputs = outputs.indices.cpu().numpy()
  72. for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
  73. feature = feature[:length]
  74. # (T,)
  75. with open(file.with_suffix(".npy"), "wb") as f:
  76. np.save(f, feature)
  77. return total_time, avg_distance
  78. @click.command()
  79. @click.argument("folder")
  80. @click.option("--num-workers", default=1)
  81. @click.option("--kmeans", default="results/hubert-vq-pretrain/kmeans.pt")
  82. def main(folder: str, num_workers: int, kmeans: str):
  83. if num_workers > 1 and WORLD_SIZE != num_workers:
  84. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  85. logger.info(f"Spawning {num_workers} workers")
  86. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  87. if visible_devices is None:
  88. visible_devices = list(range(torch.cuda.device_count()))
  89. else:
  90. visible_devices = visible_devices.split(",")
  91. processes = []
  92. for i in range(num_workers):
  93. env = os.environ.copy()
  94. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  95. env["SLURM_PROCID"] = str(i)
  96. env["SLURM_NTASKS"] = str(num_workers)
  97. processes.append(
  98. sp.Popen(
  99. [sys.executable] + sys.argv.copy(),
  100. env=env,
  101. )
  102. )
  103. for p in processes:
  104. p.wait()
  105. logger.info(f"All workers finished")
  106. return
  107. # This is a worker
  108. logger.info(f"Starting worker")
  109. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
  110. Random(42).shuffle(files)
  111. total_files = len(files)
  112. files = files[RANK::WORLD_SIZE]
  113. logger.info(f"Processing {len(files)}/{total_files} files")
  114. # Load kmeans
  115. kmeans_centers = torch.load(kmeans)["centroids"]
  116. # Batch size 64
  117. total_time = 0
  118. begin_time = time.time()
  119. processed_files = 0
  120. total_distance = 0
  121. for n_batch, idx in enumerate(range(0, len(files), 32)):
  122. batch = files[idx : idx + 32]
  123. batch_time, avg_distance = process_batch(batch, kmeans_centers)
  124. total_time += batch_time
  125. processed_files += len(batch)
  126. total_distance += avg_distance
  127. if (n_batch + 1) % 10 == 0:
  128. eta = (
  129. (time.time() - begin_time)
  130. / processed_files
  131. * (len(files) - processed_files)
  132. )
  133. logger.info(
  134. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  135. + f"err {total_distance/(n_batch+1):.2f}, ETA: {timedelta(seconds=round(eta))}s"
  136. )
  137. logger.info(
  138. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  139. )
  140. if __name__ == "__main__":
  141. main()