Переглянути джерело

Add data splitter & GPU kmeans

Lengyue 2 роки тому
батько
коміт
7f4b9631ab
2 змінених файлів з 131 додано та 0 видалено
  1. 100 0
      tools/vqgan/calculate_kmeans_init.py
  2. 31 0
      tools/vqgan/create_train_split.py

+ 100 - 0
tools/vqgan/calculate_kmeans_init.py

@@ -0,0 +1,100 @@
+from pathlib import Path
+
+import click
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from vector_quantize_pytorch.vector_quantize_pytorch import (
+    batched_bincount,
+    batched_sample_vectors,
+    cdist,
+)
+
+
+class KMeansDataset(Dataset):
+    def __init__(self, filelist):
+        root = Path(filelist).parent
+
+        with open(filelist) as f:
+            self.files = f.read().splitlines()
+
+        self.files = [root / file.strip() for file in self.files]
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        file = self.files[idx]
+        feature = np.load(file.with_suffix(".npy"))
+        return torch.from_numpy(feature).float()
+
+    @staticmethod
+    def collate_fn(features):
+        features = torch.concat(features, dim=0)
+        return features
+
+
+@click.command()
+@click.option(
+    "--filelist",
+    type=click.Path(exists=True, path_type=Path),
+    default="data/test.filelist",
+)
+@click.option("--output", type=click.Path(path_type=Path), default="kmeans.pt")
+@click.option("--num-clusters", type=int, default=2048)
+@click.option("--epochs", type=int, default=10)
+def main(filelist: Path, output: Path, num_clusters: int, epochs: int):
+    loader = DataLoader(
+        KMeansDataset(filelist),
+        batch_size=1024,
+        shuffle=True,
+        num_workers=2,
+        collate_fn=KMeansDataset.collate_fn,
+    )
+
+    means = None
+    for _ in tqdm(range(epochs), desc="Epochs", position=0):
+        total_bins = torch.zeros(1, num_clusters, dtype=torch.int64, device="cuda")
+
+        for samples in tqdm(loader, desc="Batches", position=1):
+            samples = samples.cuda()[None]
+            num_codebooks, dim, dtype = (
+                samples.shape[0],
+                samples.shape[-1],
+                samples.dtype,
+            )
+
+            if means is None:
+                means = batched_sample_vectors(samples, num_clusters)
+
+            dists = -cdist(samples, means)
+
+            buckets = torch.argmax(dists, dim=-1)
+            bins = batched_bincount(buckets, minlength=num_clusters)
+
+            zero_mask = bins == 0
+            bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+            new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)
+
+            new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
+            new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
+
+            means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)
+
+            total_bins += bins
+
+    torch.save(
+        {
+            "centroids": means,
+            "bins": bins,
+        },
+        output,
+    )
+    print(f"Saved to {output}")
+
+
+if __name__ == "__main__":
+    main()

+ 31 - 0
tools/vqgan/create_train_split.py

@@ -0,0 +1,31 @@
+from pathlib import Path
+from random import Random
+
+import click
+from tqdm import tqdm
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+def main(root):
+    files = list_files(root, AUDIO_EXTENSIONS, recursive=True, show_progress=True)
+    print(f"Found {len(files)} files")
+
+    files = [str(file) for file in tqdm(files) if file.with_suffix(".npy").exists()]
+    print(f"Found {len(files)} files with features")
+
+    Random(42).shuffle(files)
+
+    with open(root / "vq_train_filelist.txt", "w") as f:
+        f.write("\n".join(files[:-100]))
+
+    with open(root / "vq_val_filelist.txt", "w") as f:
+        f.write("\n".join(files[-100:]))
+
+    print("Done")
+
+
+if __name__ == "__main__":
+    main()