Pārlūkot izejas kodu

Add hubert feature export & finish network structure

Lengyue 2 gadi atpakaļ
vecāks
revīzija
3420acdbbe

+ 252 - 9
fish_speech/models/hubert_vq/modules.py

@@ -1,10 +1,12 @@
 import math
 import math
+from dataclasses import dataclass
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+from vector_quantize_pytorch import VectorQuantize
 
 
 from fish_speech.models.hubert_vq.utils import (
 from fish_speech.models.hubert_vq.utils import (
     convert_pad_shape,
     convert_pad_shape,
@@ -15,17 +17,251 @@ from fish_speech.models.hubert_vq.utils import (
 LRELU_SLOPE = 0.1
 LRELU_SLOPE = 0.1
 
 
 
 
@dataclass
class VQEncoderOutput:
    """Return value of VQEncoder.forward."""

    # Commitment/codebook loss returned by the VectorQuantize layer.
    loss: torch.Tensor
    # Mixed output features; forward returns x.transpose(1, 2), i.e.
    # (batch, channels, seq_len).
    features: torch.Tensor
+
 class VQEncoder(nn.Module):
 class VQEncoder(nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        channels: int = 192,
+        num_heads: int = 2,
+        num_feature_layers: int = 2,
+        num_speaker_layers: int = 4,
+        num_mixin_layers: int = 4,
+        input_downsample: bool = True,
+        code_book_size: int = 2048,
+        freeze_vq: bool = False,
+    ):
+        super().__init__()
+
+        # Feature Encoder
+        down_sample = 2 if input_downsample else 1
+
+        self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
+        self.vq = VectorQuantize(
+            dim=in_channels,
+            codebook_size=code_book_size,
+            threshold_ema_dead_code=2,
+        )
+
+        self.feature_in = nn.Linear(in_channels, channels)
+        self.feature_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_feature_layers)
+            ]
+        )
 
 
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
+        # Speaker Encoder
+        self.speaker_query = nn.Parameter(torch.randn(1, 1, channels))
+        self.speaker_in = nn.Linear(in_channels * down_sample, channels)
+        self.speaker_blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    use_relative_attn=False,
+                )
+                for _ in range(num_speaker_layers)
+            ]
         )
         )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
+
+        # Final Mixer
+        self.mixer_in = nn.ModuleList(
+            [
+                TransformerBlock(
+                    channels,
+                    num_heads,
+                    window_size=4,
+                    window_heads_share=True,
+                    proximal_init=True,
+                    proximal_bias=False,
+                    use_relative_attn=True,
+                )
+                for _ in range(num_mixin_layers)
+            ]
         )
         )
 
 
+        self.input_downsample = input_downsample
+
+        if freeze_vq:
+            for p in self.vq.parameters():
+                p.requires_grad = False
+
+            for p in self.vq_in.parameters():
+                p.requires_grad = False
+
+    def forward(self, x, key_padding_mask=None):
+        # (batch, seq_len, channels)
+
+        if self.input_downsample and key_padding_mask is not None:
+            key_padding_mask = key_padding_mask[:, ::2]
+
+        # Merge Channels
+        if self.input_downsample:
+            feature_0, feature_1 = x[:, ::2], x[:, 1::2]
+            min_len = min(feature_0.size(1), feature_1.size(1))
+            x = torch.cat([feature_0[:, :min_len], feature_1[:, :min_len]], dim=2)
+
+        # Encode Features
+        features = self.vq_in(x)
+        assert key_padding_mask.size(1) == features.size(
+            1
+        ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+
+        features, _, loss = self.vq(features, mask=~key_padding_mask)
+
+        features = self.feature_in(features)
+        for block in self.feature_blocks:
+            features = block(features, key_padding_mask=key_padding_mask)
+
+        # Encode Speaker
+        speaker = self.speaker_in(x)
+        speaker = torch.cat(
+            [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
+        )
+        for block in self.speaker_blocks:
+            speaker = block(speaker)
+
+        # Mix
+        x = features + speaker[:, :1]
+        for block in self.mixer_in:
+            x = block(x, key_padding_mask=key_padding_mask)
+
+        return VQEncoderOutput(
+            loss=loss,
+            features=x.transpose(1, 2),
+        )
+
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU MLP, each residual.

    The attention module is either the file's windowed RelativeAttention or
    a plain batch-first nn.MultiheadAttention, selected at construction via
    ``use_relative_attn``.
    """

    def __init__(
        self,
        channels,
        n_heads,
        mlp_ratio=4 * 2 / 3,
        p_dropout=0.0,
        window_size=4,
        window_heads_share=True,
        proximal_init=True,
        proximal_bias=False,
        use_relative_attn=True,
    ):
        super().__init__()

        # MLP sub-layer (pre-norm + SwiGLU feed-forward).
        self.mlp_norm = RMSNorm(channels)
        self.mlp = SwiGLU(channels, int(channels * mlp_ratio), channels, drop=p_dropout)

        # Attention sub-layer (pre-norm + one of two attention flavours).
        self.attn_norm = RMSNorm(channels)

        if not use_relative_attn:
            self.attn = nn.MultiheadAttention(
                embed_dim=channels,
                num_heads=n_heads,
                dropout=p_dropout,
                batch_first=True,
            )
        else:
            self.attn = RelativeAttention(
                channels,
                n_heads,
                p_dropout,
                window_size,
                window_heads_share,
                proximal_init,
                proximal_bias,
            )

    def forward(self, x, key_padding_mask=None):
        hidden = self.attn_norm(x)

        # RelativeAttention is self-attention with a single input; the
        # nn.MultiheadAttention path needs explicit q/k/v and returns a
        # (output, weights) tuple.
        if isinstance(self.attn, RelativeAttention):
            attn_out = self.attn(hidden, key_padding_mask=key_padding_mask)
        else:
            attn_out, _ = self.attn(
                hidden, hidden, hidden, key_padding_mask=key_padding_mask
            )

        x = x + attn_out
        return x + self.mlp(self.mlp_norm(x))
+
+
class SwiGLU(nn.Module):
    """Swish-gated linear unit MLP.

    fc1 projects to ``hidden_features`` and is split in half: the first
    half is the value path, the second the gate. Output is
    fc2(RMSNorm(dropout(value * SiLU(gate)))) with a final dropout.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias=True,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # The projection is chunked in two, so it must be even.
        assert hidden_features % 2 == 0

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.SiLU()
        self.drop1 = nn.Dropout(drop)
        self.norm = RMSNorm(hidden_features // 2)
        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def init_weights(self):
        # Re-initialize the gate half of fc1 to be near-identity at start:
        # weights ~0 and bias 1, so SiLU(gate) starts as a mild pass-through.
        half = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[half:])
        nn.init.normal_(self.fc1.weight[half:], std=1e-6)

    def forward(self, x):
        value, gate = self.fc1(x).chunk(2, dim=-1)
        hidden = self.drop1(value * self.act(gate))
        return self.drop2(self.fc2(self.norm(hidden)))
+
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to LlamaRMSNorm / T5LayerNorm).

    Normalizes by the RMS over the last dimension — no mean subtraction —
    computed in float32 for numerical stability, then applies a learned
    per-channel scale and casts back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.float()
        mean_square = h.pow(2).mean(dim=-1, keepdim=True)
        h = h * torch.rsqrt(mean_square + self.variance_epsilon)

        return self.weight * h.to(orig_dtype)
+
 
 
 class RelativeAttention(nn.Module):
 class RelativeAttention(nn.Module):
     def __init__(
     def __init__(
@@ -117,11 +353,8 @@ class RelativeAttention(nn.Module):
             key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
             key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                 -1, self.n_heads, -1, -1
                 -1, self.n_heads, -1, -1
             )
             )
-            print(key_padding_mask.shape, scores.shape)
             scores = scores.masked_fill(key_padding_mask, float("-inf"))
             scores = scores.masked_fill(key_padding_mask, float("-inf"))
 
 
-            print(scores[0, 0])
-
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
         output = torch.matmul(p_attn, value)
@@ -571,3 +804,13 @@ class EnsembleDiscriminator(nn.Module):
             fmap_gs.append(fmap_g)
             fmap_gs.append(fmap_g)
 
 
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
if __name__ == "__main__":
    # Smoke test: run one random batch (seq_len 90, 1024-dim features)
    # through the encoder, with frames 67+ marked as padding.
    vq = VQEncoder()
    x = torch.randn(1, 90, 1024)
    key_padding_mask = torch.zeros(1, 90).bool()
    key_padding_mask[:, 67:] = True

    output = vq(x, key_padding_mask=key_padding_mask)
    print(output)

+ 174 - 0
tools/calculate_hubert_features.py

@@ -0,0 +1,174 @@
# This file computes HuBERT features for audio files and saves them as .npy.
# It's mainly used to generate the training data for the VQ model.
+
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from loguru import logger
+from transformers import HubertModel
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
# Rank / world size for sharding the file list; set by SLURM under srun,
# or by the self-spawning launcher in main() (which exports SLURM_PROCID /
# SLURM_NTASKS for each child process).
RANK = int(os.environ.get("SLURM_PROCID", 0))
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))

# Loguru format tagging every line with this worker's rank.
logger_format = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
    "{extra[rank]} - <level>{message}</level>"
)
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
logger.remove()  # drop the default sink so only the rank-tagged one remains
logger.add(sys.stderr, format=logger_format)
+
+
@lru_cache(maxsize=1)
def get_hubert_model():
    """Load the Chinese HuBERT-large model in fp16 eval mode.

    Cached with lru_cache(maxsize=1) so every call after the first reuses
    the same model instance.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
    # NOTE(review): .half() is applied even on CPU, where fp16 inference is
    # slow / poorly supported — confirm this script only ever runs on GPU.
    model = model.to(device)
    model = model.half()
    model.eval()

    logger.info("Loaded model")
    return model
+
+
def process_batch(files: list[Path]) -> float:
    """Extract HuBERT features for a batch of audio files.

    For every kept file, a sibling ``<name>.npy`` with a (T, 1024) feature
    array is written next to it. Files longer than 60 seconds are skipped.

    Returns:
        Total duration (seconds) of the audio actually processed.
    """
    model = get_hubert_model()

    target_sr = 16000  # HuBERT expects 16 kHz input
    wavs = []
    # BUG FIX: keep track of which files survive the length filter. The
    # original zipped the *input* `files` with the model outputs, so one
    # skipped file shifted every subsequent feature onto the wrong file.
    kept_files = []
    max_length = 0
    total_time = 0.0

    for file in files:
        wav, sr = torchaudio.load(file)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)  # downmix to mono

        wav = torchaudio.functional.resample(wav.cuda(), sr, target_sr)[0]

        # BUG FIX: all length/duration math below used the *original*
        # sample rate even though `wav` is already resampled to 16 kHz.
        if len(wav) > target_sr * 60:  # skip clips longer than 60 s
            continue

        wavs.append(wav)
        kept_files.append(file)
        total_time += len(wav) / target_sr
        max_length = max(max_length, len(wav))

    if not wavs:
        # Every file in the batch was filtered out.
        return total_time

    # Round up to a multiple of 320 (HuBERT's total conv stride) BEFORE
    # allocating the attention mask. BUG FIX: the mask used to be created
    # with the un-rounded length, mismatching the padded waveforms.
    if max_length % 320 != 0:
        max_length += 320 - max_length % 320

    attention_mask = torch.ones(len(wavs), max_length, dtype=torch.float)
    feature_lengths = []

    for i, wav in enumerate(wavs):
        attention_mask[i, len(wav) :] = 0
        feature_lengths.append(int(len(wav) / target_sr * 50))  # 50 frames/s
        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")

    wavs = torch.stack(wavs, dim=0).half()
    attention_mask = attention_mask.cuda()

    with torch.no_grad():
        outputs = model(wavs, attention_mask=attention_mask)

    # (batch, T, 1024) — save the unpadded prefix of each item.
    outputs = outputs.last_hidden_state.cpu().numpy()

    for file, length, feature in zip(kept_files, feature_lengths, outputs):
        feature = feature[:length]

        # (T, 1024)
        with open(file.with_suffix(".npy"), "wb") as f:
            np.save(f, feature)

    return total_time
+
+
@click.command()
@click.argument("folder")
@click.option("--num-workers", default=1)
def main(folder: str, num_workers: int):
    """Extract HuBERT features for every audio file under FOLDER.

    With ``--num-workers > 1`` (and outside SLURM) the script re-launches
    itself once per worker, pinning each child to a GPU round-robin and
    sharding the file list via SLURM_PROCID / SLURM_NTASKS.
    """
    if num_workers > 1 and WORLD_SIZE != num_workers:
        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"

        logger.info(f"Spawning {num_workers} workers")

        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if visible_devices is None:
            visible_devices = list(range(torch.cuda.device_count()))
        else:
            visible_devices = visible_devices.split(",")

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            # Round-robin GPU assignment; children see themselves as SLURM
            # ranks so the worker path below shards correctly.
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            env["SLURM_PROCID"] = str(i)
            env["SLURM_NTASKS"] = str(num_workers)

            processes.append(
                sp.Popen(
                    [sys.executable] + sys.argv.copy(),
                    env=env,
                )
            )

        for p in processes:
            p.wait()

        logger.info("All workers finished")
        return

    # This is a worker.
    logger.info("Starting worker")
    files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
    # Fixed seed so every rank shuffles identically before sharding.
    Random(42).shuffle(files)

    total_files = len(files)
    files = files[RANK::WORLD_SIZE]
    logger.info(f"Processing {len(files)}/{total_files} files")

    # Hoisted the hard-coded 64 into one constant (was repeated inline).
    batch_size = 64
    total_time = 0
    begin_time = time.time()
    processed_files = 0

    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
        batch = files[idx : idx + batch_size]
        batch_time = process_batch(batch)
        total_time += batch_time
        processed_files += len(batch)

        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            # BUG FIX: dropped the stray "s" suffix — timedelta already
            # renders as H:MM:SS.
            logger.info(
                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, ETA: {timedelta(seconds=round(eta))}"
            )

        # Stop after ~1000 hours of audio across all ranks.
        if total_time * WORLD_SIZE > 3600 * 1000:
            break

    logger.info(
        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )
+
+
if __name__ == "__main__":
    # Click parses FOLDER / --num-workers from sys.argv.
    main()