Jelajahi Sumber

Update package info & hubert code

Lengyue 2 tahun lalu
induk
melakukan
c6126b4d7c

+ 8 - 9
README.md

@@ -1,18 +1,17 @@
-# Speech LLM
+# Fish Speech
+
+This repo is still under construction. Please check back later.
 
 ## Setup
 ```bash
 # Basic environment setup
-conda create -n speech-llm python=3.10
-conda activate speech-llm
+conda create -n fish-speech python=3.10
+conda activate fish-speech
 conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
 
-# Install requirements
-pip3 install -r requirements.txt
-
-# Install flash-attn
-MAX_JOBS=4 pip install flash-attn --no-build-isolation
+# Install flash-attn (for Linux)
+pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
 
-# Install speech-llm
+# Install fish-speech
 pip3 install -e .
 ```

+ 3 - 13
fish_speech/configs/llama_finetune.yaml

@@ -18,19 +18,9 @@ tokenizer:
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.text.InterleaveDataset
-  datasets:
-    - _target_: fish_speech.datasets.text.TextDataset
-      prefix: 'en/'
-    - _target_: fish_speech.datasets.text.TextDataset
-      prefix: 'zh/'
-    - _target_: fish_speech.datasets.text.TextDataset
-      prefix: 'ja/'
-    - _target_: fish_speech.datasets.text.TextDataset
-      repo: fishaudio/cn-hubert-25hz-vq
-      prefix: 'data/train'
-  probabilities: [0.2, 0.2, 0.2, 0.4]
-  seed: 42
+  _target_: fish_speech.datasets.text.TextDataset
+  repo: fishaudio/cn-hubert-25hz-vq
+  prefix: 'data/train'
 
 val_dataset:
   _target_: fish_speech.datasets.text.TextDataset

+ 347 - 0
fish_speech/models/hubert_vq/lit_module.py

@@ -0,0 +1,347 @@
+from typing import Any, Callable
+
+import torch
+import torch.nn.functional as F
+from fish_vocoder.models.vocoder import VocoderModel
+from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss
+from fish_vocoder.utils.grad_norm import grad_norm
+from fish_vocoder.utils.mask import sequence_mask
+from torch import nn
+from torch.utils.checkpoint import checkpoint as gradient_checkpointing
+
+
+class GANModel(VocoderModel):
+    def __init__(
+        self,
+        sampling_rate: int,
+        n_fft: int,
+        hop_length: int,
+        win_length: int,
+        num_mels: int,
+        optimizer: Callable,
+        lr_scheduler: Callable,
+        mel_transforms: nn.ModuleDict,
+        generator: nn.Module,
+        discriminators: nn.ModuleDict,
+        multi_resolution_stft_loss: MultiResolutionSTFTLoss,
+        num_frames: int,
+        crop_length: int | None = None,
+        checkpointing: bool = False,
+        feature_matching: bool = False,
+    ):
+        """Store the GAN components and switch the module to manual optimization.
+
+        Args:
+            sampling_rate, n_fft, hop_length, win_length, num_mels: forwarded
+                unchanged to the ``VocoderModel`` base constructor.
+            optimizer: callable that receives an iterable of parameters and
+                returns an optimizer (see ``configure_optimizers``).
+            lr_scheduler: callable that receives an optimizer and returns a
+                scheduler (see ``configure_optimizers``).
+            mel_transforms: module dict; its ``input`` and ``loss`` entries are
+                the ones used by the methods of this class.
+            generator: network mapping the ``input`` spectrogram to a waveform.
+            discriminators: named discriminators; each is called on a waveform
+                and is expected to return ``(scores, feature_maps)``.
+            multi_resolution_stft_loss: callable returning
+                ``(spectral_convergence_loss, magnitude_loss)``.
+            num_frames: stored on the instance; not read by the methods
+                visible in this file.
+            crop_length: if set, waveforms longer than this many samples are
+                randomly cropped before the adversarial losses.
+            checkpointing: enable gradient checkpointing during training.
+            feature_matching: add the feature-matching term to the generator loss.
+        """
+        super().__init__(
+            sampling_rate=sampling_rate,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            num_mels=num_mels,
+        )
+
+        # Factories for the optimizers / schedulers (called in configure_optimizers)
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+        # Spectrogram transforms
+        self.mel_transforms = mel_transforms
+
+        # Generator and discriminators
+        # NOTE: the generator is stored as passed in; nothing is compiled here.
+        self.generator = generator
+        self.discriminators = discriminators
+
+        # Loss
+        self.multi_resolution_stft_loss = multi_resolution_stft_loss
+
+        # Crop length for saving memory
+        self.num_frames = num_frames
+        self.crop_length = crop_length
+
+        # Disable automatic optimization: training_step drives both optimizers
+        # and schedulers by hand.
+        self.automatic_optimization = False
+
+        # Gradient checkpointing
+        self.checkpointing = checkpointing
+
+        # Feature matching
+        self.feature_matching = feature_matching
+
+    def configure_optimizers(self):
+        # Need two optimizers and two schedulers
+        optimizer_generator = self.optimizer_builder(self.generator.parameters())
+        optimizer_discriminator = self.optimizer_builder(
+            self.discriminators.parameters()
+        )
+
+        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+        return (
+            {
+                "optimizer": optimizer_generator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_generator,
+                    "interval": "step",
+                    "name": "optimizer/generator",
+                },
+            },
+            {
+                "optimizer": optimizer_discriminator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_discriminator,
+                    "interval": "step",
+                    "name": "optimizer/discriminator",
+                },
+            },
+        )
+
+    def training_generator(self, audio, audio_mask):
+        if self.training and self.checkpointing:
+            fake_audio, base_loss = gradient_checkpointing(
+                self.forward, audio, audio_mask, use_reentrant=False
+            )
+        else:
+            fake_audio, base_loss = self.forward(audio, audio_mask)
+
+        assert fake_audio.shape == audio.shape
+
+        # Apply mask
+        audio = audio * audio_mask
+        fake_audio = fake_audio * audio_mask
+
+        # Multi-Resolution STFT Loss
+        sc_loss, mag_loss = self.multi_resolution_stft_loss(
+            fake_audio.squeeze(1), audio.squeeze(1)
+        )
+        loss_stft = sc_loss + mag_loss
+
+        self.log(
+            "train/generator/stft",
+            loss_stft,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        # L1 Mel-Spectrogram Loss
+        # This is not used in backpropagation currently
+        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
+        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
+        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+
+        self.log(
+            "train/generator/mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        # Now, we need to reduce the length of the audio to save memory
+        if self.crop_length is not None and audio.shape[2] > self.crop_length:
+            slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,))
+
+            audio = audio[..., slice_idx : slice_idx + self.crop_length]
+            fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length]
+            audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length]
+
+            assert audio.shape == fake_audio.shape == audio_mask.shape
+
+        # Adv Loss
+        loss_adv_all = 0
+
+        for key, disc in self.discriminators.items():
+            score_fakes, feat_fake = disc(fake_audio)
+
+            # Adversarial Loss
+            score_fakes = torch.cat(score_fakes, dim=1)
+            loss_fake = torch.mean((1 - score_fakes) ** 2)
+
+            self.log(
+                f"train/generator/adv_{key}",
+                loss_fake,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_adv_all += loss_fake
+
+            if self.feature_matching is False:
+                continue
+
+            # Feature Matching Loss
+            _, feat_real = disc(audio)
+            loss_fm = 0
+            for dr, dg in zip(feat_real, feat_fake):
+                for rl, gl in zip(dr, dg):
+                    loss_fm += F.l1_loss(rl, gl)
+
+            loss_fm /= len(feat_real)
+
+            self.log(
+                f"train/generator/adv_fm_{key}",
+                loss_fm,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_adv_all += loss_fm
+
+        loss_adv_all /= len(self.discriminators)
+        loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all
+
+        self.log(
+            "train/generator/all",
+            loss_gen_all,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss_gen_all, audio, fake_audio
+
+    def training_discriminator(self, audio, fake_audio):
+        loss_disc_all = 0
+
+        for key, disc in self.discriminators.items():
+            if self.training and self.checkpointing:
+                scores, _ = gradient_checkpointing(disc, audio, use_reentrant=False)
+                score_fakes, _ = gradient_checkpointing(
+                    disc, fake_audio.detach(), use_reentrant=False
+                )
+            else:
+                scores, _ = disc(audio)
+                score_fakes, _ = disc(fake_audio.detach())
+
+            scores = torch.cat(scores, dim=1)
+            score_fakes = torch.cat(score_fakes, dim=1)
+            loss_disc = torch.mean((scores - 1) ** 2) + torch.mean((score_fakes) ** 2)
+
+            self.log(
+                f"train/discriminator/{key}",
+                loss_disc,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_disc_all += loss_disc
+
+        loss_disc_all /= len(self.discriminators)
+
+        self.log(
+            "train/discriminator/all",
+            loss_disc_all,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss_disc_all
+
+    def training_step(self, batch, batch_idx):
+        optim_g, optim_d = self.optimizers()
+
+        audio, lengths = batch["audio"], batch["lengths"]
+        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
+
+        # Generator
+        optim_g.zero_grad()
+        loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask)
+        self.manual_backward(loss_gen_all)
+
+        self.log(
+            "train/generator/grad_norm",
+            grad_norm(self.generator.parameters()),
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.clip_gradients(
+            optim_g, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+        )
+        optim_g.step()
+
+        # Discriminator
+        assert fake_audio.shape == audio.shape
+
+        optim_d.zero_grad()
+        loss_disc_all = self.training_discriminator(audio, fake_audio)
+        self.manual_backward(loss_disc_all)
+
+        for key, disc in self.discriminators.items():
+            self.log(
+                f"train/discriminator/grad_norm_{key}",
+                grad_norm(disc.parameters()),
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+        self.clip_gradients(
+            optim_d, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+        )
+        optim_d.step()
+
+        # Manual LR Scheduler
+        scheduler_g, scheduler_d = self.lr_schedulers()
+        scheduler_g.step()
+        scheduler_d.step()
+
+    def forward(self, audio, mask=None, input_spec=None):
+        if input_spec is None:
+            input_spec = self.mel_transforms.input(audio.squeeze(1))
+
+        fake_audio = self.generator(input_spec)
+
+        return fake_audio, 0
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        audio, lengths = batch["audio"], batch["lengths"]
+        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
+
+        # Generator
+        fake_audio, _ = self.forward(audio, audio_mask)
+        assert fake_audio.shape == audio.shape
+
+        # Apply mask
+        audio = audio * audio_mask
+        fake_audio = fake_audio * audio_mask
+
+        # L1 Mel-Spectrogram Loss
+        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
+        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
+        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+
+        self.log(
+            "val/metrics/mel",
+            loss_mel,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        # Report other metrics
+        self.report_val_metrics(fake_audio, audio, lengths)

+ 573 - 0
fish_speech/models/hubert_vq/modules.py

@@ -0,0 +1,573 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from fish_speech.models.hubert_vq.utils import (
+    convert_pad_shape,
+    get_padding,
+    init_weights,
+)
+
+LRELU_SLOPE = 0.1
+
+
+class VQEncoder(nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=256, nhead=4, dim_feedforward=1024, dropout=0.1, activation="gelu"
+        )
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=6, norm=nn.LayerNorm(256)
+        )
+
+
+class RelativeAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=4,
+        window_heads_share=True,
+        proximal_init=True,
+        proximal_bias=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = window_heads_share
+        self.proximal_init = proximal_init
+        self.proximal_bias = proximal_bias
+
+        self.k_channels = channels // n_heads
+        self.qkv = nn.Linear(channels, channels * 3)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if window_heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.qkv.weight)
+
+        if proximal_init:
+            with torch.no_grad():
+                # Sync qk weights
+                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
+                    self.channels : self.channels * 2
+                ]
+                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
+                    self.channels : self.channels * 2
+                ]
+
+    def forward(self, x, key_padding_mask=None):
+        # x: (batch, seq_len, channels)
+        batch_size, seq_len, _ = x.size()
+        qkv = (
+            self.qkv(x)
+            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
+            .permute(2, 0, 3, 1, 4)
+        )
+        query, key, value = torch.unbind(qkv, dim=0)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+
+        if self.window_size is not None:
+            key_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_k, seq_len
+            )
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+
+        if self.proximal_bias:
+            scores = scores + self._attention_bias_proximal(seq_len).to(
+                device=scores.device, dtype=scores.dtype
+            )
+
+        # key_padding_mask: (batch, seq_len)
+        if key_padding_mask is not None:
+            assert key_padding_mask.size() == (
+                batch_size,
+                seq_len,
+            ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+            assert (
+                key_padding_mask.dtype == torch.bool
+            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
+
+            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
+                -1, self.n_heads, -1, -1
+            )
+            print(key_padding_mask.shape, scores.shape)
+            scores = scores.masked_fill(key_padding_mask, float("-inf"))
+
+            print(scores[0, 0])
+
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, seq_len
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+
+        return output.reshape(batch_size, seq_len, self.n_heads * self.k_channels)
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        Re-index scores from relative offsets to absolute key positions.
+
+        Uses a pad / flatten / re-view sequence so that each row ends up
+        shifted by one column relative to the previous row, without any
+        gather or indexing ops.
+
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        # Each row now has 2*l entries.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        # l * 2l flattened entries plus (l - 1) pads == (l + 1) * (2l - 1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        # Rows of width 2l-1 instead of 2l produce the per-row one-column skew.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        Re-index attention weights from absolute key positions to relative offsets.
+
+        Counterpart of ``_relative_position_to_absolute_position``; it uses the
+        same pad / flatten / re-view technique in the opposite direction.
+
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column: each row grows from l to 2l-1 entries
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        # l * (2l - 1) flattened entries, written as l**2 + l * (l - 1)
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        # rows of width 2l; the first column is dropped, leaving 2l-1 offsets per row
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        1024,
+                        1024,
+                        (kernel_size, 1),
+                        1,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class EnsembleDiscriminator(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(EnsembleDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

+ 163 - 0
fish_speech/models/hubert_vq/utils.py

@@ -0,0 +1,163 @@
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+# Module-level cache of mel filterbank tensors. Filled lazily by the mel
+# functions in this module, keyed by a string describing the request.
+mel_basis = {}
+# Module-level cache of Hann window tensors. Filled lazily by the spectrogram
+# functions in this module, keyed by a string describing the request.
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    wnsize_dtype_device = str(win_size) + "_" + dtype_device
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+            dtype=y.dtype, device=y.device
+        )
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+    global mel_basis
+    dtype_device = str(spec.dtype) + "_" + str(spec.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+            dtype=spec.dtype, device=spec.device
+        )
+    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    spec = spectral_normalize_torch(spec)
+    return spec
+
+
+def mel_spectrogram_torch(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    wnsize_dtype_device = str(win_size) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+            dtype=y.dtype, device=y.device
+        )
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+            dtype=y.dtype, device=y.device
+        )
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec

+ 0 - 216
fish_speech/models/whisper_vq.py

@@ -1,216 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-from torch import nn
-from vector_quantize_pytorch import VectorQuantize
-
-from fish_speech.modules.flash_whisper import (
-    FlashWhisperEncoderLayer,
-    FlashWhisperForConditionalGeneration,
-)
-
-
-@dataclass
-class WhisperVQOutput:
-    loss: torch.Tensor
-    metrics: dict[str, torch.Tensor]
-
-
-class WhisperVQ(nn.Module):
-    def __init__(
-        self,
-        model_name_or_path: str = "openai/whisper-medium",
-        # Quantization
-        codebook_dim: int = 32,
-        codebook_size: int = 4096,
-        codebook_decay: float = 0.9,
-        threshold_ema_dead_code: int = 0,
-        use_cosine_similarity: bool = True,
-        downsample: bool = True,
-        # Attention
-        post_attention_depth: int = 2,
-    ):
-        super().__init__()
-
-        self.whisper = FlashWhisperForConditionalGeneration.from_pretrained(
-            model_name_or_path
-        )
-        self.whisper.gradient_checkpointing_enable()
-
-        # Freeze Whisper
-        for param in self.whisper.parameters():
-            param.requires_grad = False
-
-        # Store vars
-        self.downsample = downsample
-        self.codebook_dim = codebook_dim
-        self.codebook_size = codebook_size
-
-        # Pre-quantization
-        whisper_config = self.whisper.model.config
-        encoder_width = whisper_config.encoder_attention_heads * 64
-
-        self.pre_ln = nn.LayerNorm(encoder_width)
-        self.pre_mlp = nn.Sequential(
-            nn.Linear(encoder_width, whisper_config.encoder_ffn_dim),
-            nn.GELU(),
-            nn.Linear(whisper_config.encoder_ffn_dim, encoder_width),
-        )
-
-        # Quantization
-        self.quantizer = VectorQuantize(
-            dim=encoder_width,
-            codebook_size=codebook_size,
-            codebook_dim=codebook_dim,
-            decay=codebook_decay,
-            commitment_weight=1.0,
-            threshold_ema_dead_code=threshold_ema_dead_code,
-            use_cosine_sim=use_cosine_similarity,
-        )
-        self.pad_embedding = nn.Parameter(torch.randn(encoder_width))
-
-        # Post-quantization
-        self.post_positional_embedding = nn.Embedding(
-            whisper_config.max_source_positions, encoder_width
-        )
-        self.post_attention = nn.Sequential(
-            *[
-                FlashWhisperEncoderLayer(
-                    config=whisper_config,
-                )
-                for _ in range(post_attention_depth)
-            ]
-        )
-        self.post_ln = nn.LayerNorm(encoder_width)
-
-    def encode(
-        self,
-        input_features: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if attention_mask is not None:
-            assert attention_mask.ndim == 2, "Attention mask must be 2D"
-
-            # Whisper will downsample by 2
-            attention_mask = attention_mask[:, ::2]
-
-        with torch.no_grad():
-            hidden_states = self.whisper.model.encoder(
-                input_features,
-            ).last_hidden_state
-
-            x = hidden_states
-            if self.downsample:
-                x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
-
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, ::2]
-
-        x = x + self.pre_mlp(self.pre_ln(x))
-        quantized, indices, loss = self.quantizer(
-            x, mask=attention_mask.bool() if attention_mask is not None else None
-        )
-
-        # Fill masked positions with pad embedding
-        if attention_mask is not None:
-            quantized[attention_mask == 0] = self.pad_embedding
-
-        return quantized, indices, loss, hidden_states
-
-    def decode(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-        # Upsample
-        if self.downsample:
-            hidden_states = hidden_states.repeat_interleave(2, dim=1)
-
-        # Inject position embeddings
-        positions = torch.arange(
-            0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device
-        )
-        x = hidden_states + self.post_positional_embedding(positions)
-
-        # Decode
-        for layer in self.post_attention:
-            x = layer(x, None, None)[0]
-        hidden_states = self.post_ln(hidden_states)
-
-        return hidden_states
-
-    def forward(
-        self,
-        input_features: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        decoder_input_ids: torch.Tensor,
-        decoder_attention_mask: torch.Tensor,
-        labels: torch.Tensor,
-        # Audio, not used here
-        input_values: Optional[torch.Tensor] = None,
-    ) -> WhisperVQOutput:
-        quantize, _, vq_loss, teacher_hidden_states = self.encode(
-            input_features=input_features,
-            attention_mask=encoder_attention_mask,
-        )
-        vq_hidden_states = self.decode(quantize)
-
-        # student cross entropy loss
-        outputs = self.whisper(
-            encoder_outputs=(vq_hidden_states,),
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            labels=labels,
-        )
-        student_ce_loss = outputs.loss
-        student_logits = outputs.logits
-
-        # teacher cross entropy loss
-        with torch.no_grad():
-            outputs = self.whisper(
-                encoder_outputs=(teacher_hidden_states,),
-                decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
-                labels=labels,
-            )
-            teacher_ce_loss = outputs.loss
-            teacher_logits = outputs.logits
-
-        # KL divergence
-        kl_loss = nn.functional.kl_div(
-            nn.functional.log_softmax(student_logits, dim=-1),
-            nn.functional.softmax(teacher_logits, dim=-1),
-            reduction="batchmean",
-        )
-
-        loss = vq_loss + student_ce_loss + kl_loss
-
-        return WhisperVQOutput(
-            loss=loss,
-            metrics={
-                "vq_loss": vq_loss,
-                "student_ce_loss": student_ce_loss,
-                "teacher_ce_loss": teacher_ce_loss,
-                "kl_loss": kl_loss,
-            },
-        )
-
-
-if __name__ == "__main__":
-    from torch.utils.data import DataLoader
-    from transformers import WhisperProcessor
-
-    from fish_speech.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
-
-    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
-    model = WhisperVQ()
-
-    ds = WhisperVQDataset(
-        "filelists/whisper-vq.train.test.filelist", "openai/whisper-medium"
-    )
-    loader = DataLoader(ds, batch_size=8, collate_fn=WhisperVQCollator())
-
-    for batch in loader:
-        output = model(**batch)
-        print(output)
-        break

+ 0 - 313
fish_speech/modules/flash_whisper.py

@@ -1,313 +0,0 @@
-# A whisper that supports flash-attention and dynamic input length.
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-from transformers.modeling_outputs import BaseModelOutput
-from transformers.models.whisper.modeling_whisper import (
-    WhisperAttention,
-    WhisperConfig,
-    WhisperDecoder,
-    WhisperDecoderLayer,
-    WhisperEncoder,
-    WhisperEncoderLayer,
-    WhisperForConditionalGeneration,
-    WhisperModel,
-)
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class FlashWhisperAttention(WhisperAttention):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj - don't scale here since Flash Attention performs this under the hood
-        query_states = self._shape(self.q_proj(hidden_states), -1, bsz)
-
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
-        ):
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        attn_output = F.scaled_dot_product_attention(
-            query=query_states,
-            key=key_states,
-            value=value_states,
-            attn_mask=attention_mask,
-            scale=self.scaling,
-        )
-
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned across GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
-class FlashWhisperEncoderLayer(WhisperEncoderLayer):
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.self_attn = FlashWhisperAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
-            dropout=config.attention_dropout,
-        )
-
-
-class FlashWhisperDecoderLayer(WhisperDecoderLayer):
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.self_attn = FlashWhisperAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
-            dropout=config.attention_dropout,
-            is_decoder=True,
-        )
-
-
-class FlashWhisperEncoder(WhisperEncoder):
-    """
-    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
-    [`WhisperEncoderLayer`].
-
-    Args:
-        config: WhisperConfig
-    """
-
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.layers = nn.ModuleList(
-            [FlashWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
-        )
-
-    def forward(
-        self,
-        input_features,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        r"""
-        Args:
-            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
-                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
-                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
-                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
-                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
-                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
-            attention_mask (`torch.Tensor`)`, *optional*):
-                Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
-                but it is not used. By default the silence in the input log mel spectrogram are ignored.
-            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
-                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
-                - 1 indicates the head is **not masked**,
-                - 0 indicates the head is **masked**.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more detail.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        """
-
-        # If we receive the output of input feature directly, just return it
-        if input_features.shape[-2:] == (1500, 1024):
-            if not return_dict:
-                return (input_features,)
-
-            return BaseModelOutput(last_hidden_state=input_features)
-
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
-        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
-
-        inputs_embeds = inputs_embeds.permute(0, 2, 1)
-        embed_pos = self.embed_positions.weight
-
-        hidden_states = inputs_embeds + embed_pos[None, : inputs_embeds.size(1), :]
-        hidden_states = nn.functional.dropout(
-            hidden_states, p=self.dropout, training=self.training
-        )
-
-        encoder_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-
-        # check if head_mask has a correct number of layers specified if desired
-        if head_mask is not None:
-            assert head_mask.size()[0] == (
-                len(self.layers)
-            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
-
-        for idx, encoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            to_drop = False
-            if self.training:
-                dropout_probability = torch.rand([])
-                if dropout_probability < self.layerdrop:  # skip the layer
-                    to_drop = True
-
-            if to_drop:
-                layer_outputs = (None, None)
-            else:
-                if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
-                        hidden_states,
-                        None,
-                        (head_mask[idx] if head_mask is not None else None),
-                    )
-                else:
-                    layer_outputs = encoder_layer(
-                        hidden_states,
-                        None,
-                        layer_head_mask=(
-                            head_mask[idx] if head_mask is not None else None
-                        ),
-                        output_attentions=output_attentions,
-                    )
-
-                hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
-
-        hidden_states = self.layer_norm(hidden_states)
-
-        # Simply set states to zero for attention_mask
-        # hidden_states[:, 40:, :] = 0
-
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, encoder_states, all_attentions]
-                if v is not None
-            )
-        return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=encoder_states,
-            attentions=all_attentions,
-        )
-
-
-class FlashWhisperDecoder(WhisperDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
-    [`WhisperDecoderLayer`]
-
-    Args:
-        config: WhisperConfig
-    """
-
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.layers = nn.ModuleList(
-            [FlashWhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
-        )
-
-
-class FlashWhisperModel(WhisperModel):
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.encoder = FlashWhisperEncoder(config)
-        self.decoder = FlashWhisperDecoder(config)
-        self.post_init()
-
-
-class FlashWhisperForConditionalGeneration(WhisperForConditionalGeneration):
-    def __init__(self, config: WhisperConfig):
-        super().__init__(config)
-
-        self.model = FlashWhisperModel(config)
-        self.post_init()

+ 40 - 0
pyproject.toml

@@ -0,0 +1,40 @@
+[project]
+name = "fish-speech"
+version = "0.1.0"
+authors = [
+    {name = "Lengyue", email = "lengyue@lengyue.me"},
+]
+description = "Fish Speech"
+readme = "README.md"
+requires-python = ">=3.10"
+keywords = ["TTS", "Speech"]
+license = {text = "BSD-3-Clause"}
+classifiers = [
+    "Programming Language :: Python :: 3",
+]
+dependencies = [
+    "transformers>=4.34.1",
+    "datasets>=2.14.5",
+    "bitsandbytes>=0.41.1",
+    "peft>=0.5.0",
+    "lightning>=2.1.0",
+    "hydra-core>=1.3.2",
+    "tensorboard>=2.14.1",
+    "natsort>=8.4.0",
+    "einops>=0.7.0",
+    "librosa>=0.10.1",
+    "vector-quantize-pytorch>=1.9.18",
+    "rich>=13.5.3",
+    "cn2an",
+    "pypinyin",
+    "jieba",
+    "g2p_en",
+    "pyopenjtalk",
+]
+
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+# Discover fish_speech and all of its subpackages; an explicit
+# packages = ["fish_speech"] list would ship only the top-level package.
+[tool.setuptools.packages.find]
+include = ["fish_speech", "fish_speech.*"]

+ 0 - 17
requirements.txt

@@ -1,17 +0,0 @@
-transformers>=4.34.1
-datasets>=2.14.5
-bitsandbytes>=0.41.1
-peft>=0.5.0
-lightning>=2.1.0
-hydra-core>=1.3.2
-tensorboard>=2.14.1
-natsort>=8.4.0
-einops>=0.7.0
-librosa>=0.10.1
-vector-quantize-pytorch>=1.9.18
-rich>=13.5.3
-cn2an
-pypinyin
-jieba
-g2p_en
-pyopenjtalk

+ 0 - 7
setup.py

@@ -1,7 +0,0 @@
-from setuptools import find_packages, setup
-
-setup(
-    name="fish-speech",
-    version="0.1.0",
-    packages=find_packages(include=["fish_speech", "fish_speech.*"]),
-)