Procházet zdrojové kódy

Add hubert feature export & finish network structure

Lengyue před 2 roky
rodič
revize
3420acdbbe

+ 252 - 9
fish_speech/models/hubert_vq/modules.py

@@ -1,10 +1,12 @@
 import math
+from dataclasses import dataclass
 
 import torch
 from torch import nn
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.hubert_vq.utils import (
     convert_pad_shape,
@@ -15,17 +17,251 @@ from fish_speech.models.hubert_vq.utils import (
 LRELU_SLOPE = 0.1
 
 
+@dataclass
+class VQEncoderOutput:
+    loss: torch.Tensor
+    features: torch.Tensor
+
+
 class VQEncoder(nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        channels: int = 192,
+        num_heads: int = 2,
+        num_feature_layers: int = 2,
+        num_speaker_layers: int = 4,
+        num_mixin_layers: int = 4,
+        input_downsample: bool = True,
+        code_book_size: int = 2048,
+        freeze_vq: bool = False,
+    ):
+        super().__init__()
+
+        # Feature Encoder
+        down_sample = 2 if input_downsample else 1
+
+        self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
+        self.vq = VectorQuantize(
+            dim=in_channels,
+            codebook_size=code_book_size,
+            threshold_ema_dead_code=2,
+        )
+
+        self.feature_in = nn.Linear(in_channels, channels)
+        self.feature_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_feature_layers)
+            ]
+        )
 
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
+        # Speaker Encoder
+        self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
+        self.speaker_in = nn.Linear(in_channels * down_sample, channels)
+        self.speaker_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    use_relative_attn=False,
+                )
+                for _ in range(num_speaker_layers)
+            ]
         )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
+
+        # Final Mixer
+        self.mixer_in = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_mixin_layers)
+            ]
         )
 
+        self.input_downsample = input_downsample
+
+        if freeze_vq:
+            for p in self.vq.parameters():
+                p.requires_grad = False
+
+            for p in self.vq_in.parameters():
+                p.requires_grad = False
+
+    def forward(self, x, key_padding_mask=None):
+        # (batch, seq_len, channels)
+
+        if self.input_downsample and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask[:, ::2]
+
+        # Merge Channels
+        if self.input_downsample:
+            feature_0, feature_1 = x[:, ::2], x[:, 1::2]
+            min_len = min(feature_0.size(1), feature_1.size(1))
+            x = torch.cat([feature_0[:, :min_len], feature_1[:, :min_len]], dim=2)
+
+        # Encode Features
+        features = self.vq_in(x)
+        assert key_padding_mask.size(1) == features.size(
+            1
+        ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+
+        features, _, loss = self.vq(features, mask=~key_padding_mask)
+
+        features = self.feature_in(features)
+        for block in self.feature_blocks:
+            features = block(features, key_padding_mask=key_padding_mask)
+
+        # Encode Speaker
+        speaker = self.speaker_in(x)
+        speaker = torch.cat(
+            [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
+        )
+        for block in self.speaker_blocks:
+            speaker = block(speaker)
+
+        # Mix
+        x = features + speaker[:, :1]
+        for block in self.mixer_in:
+            x = block(x, key_padding_mask=key_padding_mask)
+
+        return VQEncoderOutput(
+            loss=loss,
+            features=x.transpose(1, 2),
+        )
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        n_heads,
+        mlp_ratio=4 * 2 / 3,
+        p_dropout=0.0,
+        window_size=4,
+        window_heads_share=True,
+        proximal_init=True,
+        proximal_bias=False,
+        use_relative_attn=True,
+    ):
+        super().__init__()
+
+        self.attn_norm = RMSNorm(channels)
+
+        if use_relative_attn:
+            self.attn = RelativeAttention(
+                channels,
+                n_heads,
+                p_dropout,
+                window_size,
+                window_heads_share,
+                proximal_init,
+                proximal_bias,
+            )
+        else:
+            self.attn = nn.MultiheadAttention(
+                embed_dim=channels,
+                num_heads=n_heads,
+                dropout=p_dropout,
+                batch_first=True,
+            )
+
+        self.mlp_norm = RMSNorm(channels)
+        self.mlp = SwiGLU(channels, int(channels * mlp_ratio), channels, drop=p_dropout)
+
+    def forward(self, x, key_padding_mask=None):
+        norm_x = self.attn_norm(x)
+
+        if isinstance(self.attn, RelativeAttention):
+            attn = self.attn(norm_x, key_padding_mask=key_padding_mask)
+        else:
+            attn, _ = self.attn(
+                norm_x, norm_x, norm_x, key_padding_mask=key_padding_mask
+            )
+
+        x = x + attn
+        x = x + self.mlp(self.mlp_norm(x))
+
+        return x
+
+
+class SwiGLU(nn.Module):
+    """
+    Swish-Gated Linear Unit (SwiGLU) activation function
+    """
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        bias=True,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        assert hidden_features % 2 == 0
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = nn.SiLU()
+        self.drop1 = nn.Dropout(drop)
+        self.norm = RMSNorm(hidden_features // 2)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias)
+        self.drop2 = nn.Dropout(drop)
+
+    def init_weights(self):
+        # override init of fc1 w/ gate portion set to weight near zero, bias=1
+        fc1_mid = self.fc1.bias.shape[0] // 2
+        nn.init.ones_(self.fc1.bias[fc1_mid:])
+        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x1, x2 = x.chunk(2, dim=-1)
+
+        x = x1 * self.act(x2)
+        x = self.drop1(x)
+        x = self.norm(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+
 
 class RelativeAttention(nn.Module):
     def __init__(
@@ -117,11 +353,8 @@ class RelativeAttention(nn.Module):
             key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                 -1, self.n_heads, -1, -1
             )
-            print(key_padding_mask.shape, scores.shape)
             scores = scores.masked_fill(key_padding_mask, float("-inf"))
 
-            print(scores[0, 0])
-
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
@@ -571,3 +804,13 @@ class EnsembleDiscriminator(nn.Module):
             fmap_gs.append(fmap_g)
 
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+if __name__ == "__main__":
+    vq = VQEncoder()
+    x = torch.randn(1, 90, 1024)
+    key_padding_mask = torch.zeros(1, 90).bool()
+    key_padding_mask[:, 67:] = True
+
+    output = vq(x, key_padding_mask=key_padding_mask)
+    print(output)

+ 174 - 0
tools/calculate_hubert_features.py

@@ -0,0 +1,174 @@
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from loguru import logger
+from transformers import HubertModel
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+    "<level>{level: <8}</level> | "
+    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
+    "{extra[rank]} - <level>{message}</level>"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_hubert_model():
+    model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
+    model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    model = model.half()
+    model.eval()
+
+    logger.info(f"Loaded model")
+    return model
+
+
+def process_batch(files: list[Path]):
+    model = get_hubert_model()
+
+    wavs = []
+    max_length = total_time = 0
+
+    for file in files:
+        wav, sr = torchaudio.load(file)
+        if wav.shape[0] > 1:
+            wav = wav.mean(dim=0, keepdim=True)
+
+        wav = torchaudio.functional.resample(wav.cuda(), sr, 16000)[0]
+
+        if len(wav) > sr * 60:
+            continue
+
+        wavs.append(wav)
+        total_time += len(wav) / sr
+        max_length = max(max_length, len(wav))
+
+    # Pad to max length
+    attention_mask = torch.ones(len(wavs), max_length, dtype=torch.float)
+    feature_lengths = []
+
+    if max_length % 320 != 0:
+        max_length += 320 - max_length % 320
+
+    for i, wav in enumerate(wavs):
+        attention_mask[i, len(wav) :] = 0
+        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+        feature_lengths.append(int(len(wav) / sr * 50))
+
+    wavs = torch.stack(wavs, dim=0).half()
+    attention_mask = attention_mask.cuda()
+
+    # Calculate lengths
+    with torch.no_grad():
+        outputs = model(wavs, attention_mask=attention_mask)
+
+    # Save to disk
+    outputs = outputs.last_hidden_state.cpu().numpy()
+
+    for file, length, feature in zip(files, feature_lengths, outputs):
+        feature = feature[:length]
+
+        # (T, 1024)
+        with open(file.with_suffix(".npy"), "wb") as f:
+            np.save(f, feature)
+
+    return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+def main(folder: str, num_workers: int):
+    if num_workers > 1 and WORLD_SIZE != num_workers:
+        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+        logger.info(f"Spawning {num_workers} workers")
+
+        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if visible_devices is None:
+            visible_devices = list(range(torch.cuda.device_count()))
+        else:
+            visible_devices = visible_devices.split(",")
+
+        processes = []
+        for i in range(num_workers):
+            env = os.environ.copy()
+            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+            env["SLURM_PROCID"] = str(i)
+            env["SLURM_NTASKS"] = str(num_workers)
+
+            processes.append(
+                sp.Popen(
+                    [sys.executable] + sys.argv.copy(),
+                    env=env,
+                )
+            )
+
+        for p in processes:
+            p.wait()
+
+        logger.info(f"All workers finished")
+        return
+
+    # This is a worker
+    logger.info(f"Starting worker")
+    files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+    Random(42).shuffle(files)
+
+    total_files = len(files)
+    files = files[RANK::WORLD_SIZE]
+    logger.info(f"Processing {len(files)}/{total_files} files")
+
+    # Batch size 64
+    total_time = 0
+    begin_time = time.time()
+    processed_files = 0
+
+    for n_batch, idx in enumerate(range(0, len(files), 64)):
+        batch = files[idx : idx + 64]
+        batch_time = process_batch(batch)
+        total_time += batch_time
+        processed_files += len(batch)
+
+        if (n_batch + 1) % 10 == 0:
+            eta = (
+                (time.time() - begin_time)
+                / processed_files
+                * (len(files) - processed_files)
+            )
+            logger.info(
+                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}s"
+            )
+
+        # Stop after 1000 hours
+        if total_time * WORLD_SIZE > 3600 * 1000:
+            break
+
+    logger.info(
+        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+    )
+
+
+if __name__ == "__main__":
+    main()