extract_vq.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import os
  2. import subprocess as sp
  3. import sys
  4. import time
  5. from datetime import timedelta
  6. from functools import lru_cache
  7. from pathlib import Path
  8. from random import Random
  9. import click
  10. import numpy as np
  11. import torch
  12. import torchaudio
  13. from einops import rearrange
  14. from hydra import compose, initialize
  15. from hydra.utils import instantiate
  16. from lightning import LightningModule
  17. from loguru import logger
  18. from omegaconf import OmegaConf
  19. from fish_speech.models.vqgan.utils import sequence_mask
  20. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
# Register an "eval" resolver so configs can use ${eval:...} interpolations.
# NOTE(review): this hands the interpolated config string to Python's built-in
# eval(), so only load configuration files from trusted sources.
OmegaConf.register_new_resolver("eval", eval)
# This script encodes audio files into discrete VQ codebook indices using the VQGAN model.
# The indices of each audio file are saved as a .npy file next to that audio file.
  25. RANK = int(os.environ.get("SLURM_PROCID", 0))
  26. WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
  27. logger_format = (
  28. "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
  29. "<level>{level: <8}</level> | "
  30. "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
  31. "{extra[rank]} - <level>{message}</level>"
  32. )
  33. logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
  34. logger.remove()
  35. logger.add(sys.stderr, format=logger_format)
  36. @lru_cache(maxsize=1)
  37. def get_model(
  38. config_name: str = "vqgan",
  39. checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
  40. ):
  41. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  42. cfg = compose(config_name=config_name)
  43. model: LightningModule = instantiate(cfg.model)
  44. state_dict = torch.load(
  45. checkpoint_path,
  46. map_location=model.device,
  47. )
  48. if "state_dict" in state_dict:
  49. state_dict = state_dict["state_dict"]
  50. model.load_state_dict(state_dict, strict=True)
  51. model.eval()
  52. model.cuda()
  53. logger.info(f"Loaded model")
  54. return model
  55. def process_batch(files: list[Path], model) -> float:
  56. wavs = []
  57. audio_lengths = []
  58. max_length = total_time = 0
  59. for file in files:
  60. wav, sr = torchaudio.load(file)
  61. if wav.shape[0] > 1:
  62. wav = wav.mean(dim=0, keepdim=True)
  63. wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
  64. wavs.append(wav)
  65. total_time += len(wav) / model.sampling_rate
  66. max_length = max(max_length, len(wav))
  67. audio_lengths.append(len(wav))
  68. # Pad to max length
  69. for i, wav in enumerate(wavs):
  70. wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
  71. audios = torch.stack(wavs, dim=0)[:, None]
  72. audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
  73. # Calculate lengths
  74. with torch.no_grad():
  75. # VQ Encoder
  76. features = gt_mels = model.mel_transform(
  77. audios, sample_rate=model.sampling_rate
  78. )
  79. if model.downsample is not None:
  80. features = model.downsample(features)
  81. feature_lengths = (
  82. audio_lengths
  83. / model.hop_length
  84. / (model.downsample.total_strides if model.downsample is not None else 1)
  85. ).long()
  86. feature_masks = torch.unsqueeze(
  87. sequence_mask(feature_lengths, features.shape[2]), 1
  88. ).to(gt_mels.dtype)
  89. text_features = model.mel_encoder(features, feature_masks)
  90. _, indices, _ = model.vq_encoder(text_features, feature_masks)
  91. if indices.ndim == 4:
  92. # Grouped vq
  93. assert indices.shape[-1] == 1, f"Residual vq is not supported"
  94. indices = indices.squeeze(-1)
  95. elif indices.ndim == 2:
  96. # Single vq
  97. indices = indices.unsqueeze(0)
  98. else:
  99. raise ValueError(f"Invalid indices shape {indices.shape}")
  100. indices = rearrange(indices, "c b t -> b c t")
  101. # Save to disk
  102. outputs = indices.cpu().numpy()
  103. for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
  104. feature = feature[:, :length]
  105. # (T,)
  106. with open(file.with_suffix(".npy"), "wb") as f:
  107. np.save(f, feature)
  108. return total_time
  109. @click.command()
  110. @click.argument("folder")
  111. @click.option("--num-workers", default=1)
  112. @click.option("--config-name", default="vqgan_pretrain")
  113. @click.option(
  114. "--checkpoint-path",
  115. default="checkpoints/vqgan-v1.pth",
  116. )
  117. @click.option("--batch-size", default=64)
  118. @click.option("--filelist", default=None, type=Path)
  119. def main(
  120. folder: str,
  121. num_workers: int,
  122. config_name: str,
  123. checkpoint_path: str,
  124. batch_size: int,
  125. filelist: Path,
  126. ):
  127. if num_workers > 1 and WORLD_SIZE != num_workers:
  128. assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
  129. logger.info(f"Spawning {num_workers} workers")
  130. visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  131. if visible_devices is None:
  132. visible_devices = list(range(torch.cuda.device_count()))
  133. else:
  134. visible_devices = visible_devices.split(",")
  135. processes = []
  136. for i in range(num_workers):
  137. env = os.environ.copy()
  138. env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
  139. env["SLURM_PROCID"] = str(i)
  140. env["SLURM_NTASKS"] = str(num_workers)
  141. processes.append(
  142. sp.Popen(
  143. [sys.executable] + sys.argv.copy(),
  144. env=env,
  145. )
  146. )
  147. for p in processes:
  148. p.wait()
  149. logger.info(f"All workers finished")
  150. return
  151. # This is a worker
  152. logger.info(f"Starting worker")
  153. if filelist:
  154. with open(filelist, "r", encoding="utf-8") as f:
  155. # files = [Path(line..strip().split("|")[0]) for line in f]
  156. files = set()
  157. countSame = 0
  158. countNotFound = 0
  159. for line in f.readlines():
  160. file = Path(line.strip().split("|")[0])
  161. if file in files:
  162. print(f"重复音频文本:{line}")
  163. countSame += 1
  164. continue
  165. if not os.path.isfile(file):
  166. # 过滤数据集错误:不存在对应音频
  167. print(f"没有找到对应的音频:{file}")
  168. countNotFound += 1
  169. continue
  170. files.add(file)
  171. files = list(files)
  172. print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")
  173. else:
  174. files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
  175. Random(42).shuffle(files)
  176. total_files = len(files)
  177. files = files[RANK::WORLD_SIZE]
  178. logger.info(f"Processing {len(files)}/{total_files} files")
  179. # Batch processing
  180. total_time = 0
  181. begin_time = time.time()
  182. processed_files = 0
  183. model = get_model(config_name, checkpoint_path)
  184. for n_batch, idx in enumerate(range(0, len(files), batch_size)):
  185. batch = files[idx : idx + batch_size]
  186. batch_time = process_batch(batch, model)
  187. total_time += batch_time
  188. processed_files += len(batch)
  189. if (n_batch + 1) % 10 == 0:
  190. eta = (
  191. (time.time() - begin_time)
  192. / processed_files
  193. * (len(files) - processed_files)
  194. )
  195. logger.info(
  196. f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
  197. + f"ETA: {timedelta(seconds=round(eta))}s"
  198. )
  199. logger.info(
  200. f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
  201. )
# Run the CLI only when this file is executed as a script. The launcher branch
# of main() re-runs this same command line through sys.executable for each worker.
if __name__ == "__main__":
    main()