소스 검색

Update diffusion code

Lengyue 2 년 전
부모
커밋
5e80bcd9e0

+ 1 - 1
fish_speech/configs/base.yaml

@@ -34,7 +34,7 @@ callbacks:
     _target_: lightning.pytorch.callbacks.ModelCheckpoint
     dirpath: ${paths.ckpt_dir}
     filename: "step_{step:09d}"
-    save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+    save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
     save_top_k: 5 # save 5 latest checkpoints
     monitor: step # use step to monitor checkpoints
     mode: max # save the latest checkpoint with the highest global_step

+ 19 - 9
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -7,13 +7,13 @@ project: hubert_vq_diffusion
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: 4
+  devices: 1
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
     static_graph: true
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
-  precision: bf16-mixed
+  precision: 16-mixed
   max_steps: 1_000_000
   val_check_interval: 1000
 
@@ -41,8 +41,8 @@ data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
+  num_workers: 0 #16
+  batch_size: 32
   val_batch_size: 4
 
 # Model Configuration
@@ -93,11 +93,21 @@ model:
   #   time_embedding_type: "positional"
 
   denoiser:
-    _target_: fish_speech.models.vq_diffusion.unet1d.Unet1DDenoiser
-    dim: 64
-    dim_mults: [1, 2, 4]
-    groups: 8
-    pe_scale: 1000
+    _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
+    in_channels: 128
+    out_channels: 128
+    d_encoder: 128
+    residual_channels: 512
+    residual_layers: 20
+    use_linear_bias: false
+    dilation_cycle: 2
+
+  # denoiser:
+  #   _target_: fish_speech.models.vq_diffusion.unet1d.Unet1DDenoiser
+  #   dim: 64
+  #   dim_mults: [1, 2, 4]
+  #   groups: 8
+  #   pe_scale: 1000
 
   vocoder:
     _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1

+ 14 - 17
fish_speech/datasets/vqgan.py

@@ -29,7 +29,7 @@ class VQGANDataset(Dataset):
         self.files = [
             root / line.strip()
             for line in filelist.read_text().splitlines()
-            if ("Genshin" in line or "StarRail" in line)
+            # if ("Genshin" in line or "StarRail" in line)
         ]
         self.sample_rate = sample_rate
         self.hop_length = hop_length
@@ -48,26 +48,23 @@ class VQGANDataset(Dataset):
         if self.slice_frames is not None and features.shape[0] > self.slice_frames:
             start = np.random.randint(0, features.shape[0] - self.slice_frames)
             features = features[start : start + self.slice_frames]
-            feature_hop_length = features.shape[0] * (32000 // 50)
+
+            start_in_seconds, end_in_seconds = (
+                start * 320 / 16000,
+                (start + self.slice_frames) * 320 / 16000,
+            )
             audio = audio[
-                start
-                * feature_hop_length : (start + self.slice_frames)
-                * feature_hop_length
+                int(start_in_seconds * self.sample_rate) : int(
+                    end_in_seconds * self.sample_rate
+                )
             ]
 
-        # if features.shape[0] % 2 != 0:
-        #     features = features[:-1]
-
-        # if len(audio) > len(features) * self.hop_length:
-        #     audio = audio[: features.shape[0] * self.hop_length]
+        if len(audio) == 0:
+            return None
 
-        # if len(audio) < len(features) * self.hop_length:
-        #     audio = np.pad(
-        #         audio,
-        #         (0, len(features) * self.hop_length - len(audio)),
-        #         mode="constant",
-        #         constant_values=0,
-        #     )
+        max_value = np.abs(audio).max()
+        if max_value > 1.0:
+            audio = audio / max_value
 
         return {
             "audio": torch.from_numpy(audio),

+ 5 - 14
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -88,19 +88,6 @@ class ConvNeXtBlock(nn.Module):
         return x
 
 
-@dataclass
-class ConvNext1DOutput(BaseOutput):
-    """
-    The output of [`UNet1DModel`].
-
-    Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
-            The hidden states output from the last layer of the model.
-    """
-
-    sample: torch.FloatTensor
-
-
 class ConvNext1DModel(ModelMixin, ConfigMixin):
     r"""
     A ConvNext model that takes a noisy sample and a timestep and returns a sample shaped output.
@@ -176,6 +163,10 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
         self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
 
+        # Initialize weights
+        nn.init.normal_(self.out_proj.weight, mean=0, std=0.01)
+        nn.init.zeros_(self.out_proj.bias)
+
         # Blocks
         self.blocks = nn.ModuleList(
             [
@@ -199,7 +190,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         timestep: Union[torch.Tensor, float, int],
         sample_mask: Optional[torch.Tensor] = None,
         condition: Optional[torch.Tensor] = None,
-    ) -> Union[ConvNext1DOutput, Tuple]:
+    ) -> torch.FloatTensor:
         r"""
         The [`ConvNext1DModel`] forward method.
 

+ 25 - 10
fish_speech/models/vq_diffusion/lit_module.py

@@ -44,7 +44,6 @@ class VQDiffusion(L.LightningModule):
         self.mel_transform = mel_transform
         self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
         self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)
-        self.noise_scheduler_infer.set_timesteps(20)
 
         # Modules
         self.vq_encoder = vq_encoder
@@ -73,10 +72,13 @@ class VQDiffusion(L.LightningModule):
         }
 
     def normalize_mels(self, x):
-        return (x + 11.5129251) / (1 + 11.5129251) * 2 - 1
+        # x is in range -10.1 to 3.1, normalize to -1 to 1
+        x_min, x_max = -10.1, 3.1
+        return (x - x_min) / (x_max - x_min) * 2 - 1
 
     def denormalize_mels(self, x):
-        return (x + 1) / 2 * (1.0 + 11.5129251) - 11.5129251
+        x_min, x_max = -10.1, 3.1
+        return (x + 1) / 2 * (x_max - x_min) + x_min
 
     def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -99,7 +101,7 @@ class VQDiffusion(L.LightningModule):
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, _ = self.vq_encoder(features, feature_masks)
+        vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
         vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
@@ -127,13 +129,26 @@ class VQDiffusion(L.LightningModule):
         model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
 
         # MSE loss without the mask
-        loss = (
-            (model_output * mel_masks - normalized_gt_mels * mel_masks) ** 2
+        # noise_loss = (
+        #     (model_output * mel_masks - normalized_gt_mels * mel_masks) ** 2
+        # ).sum() / (mel_masks.sum() * gt_mels.shape[1])
+        noise_loss = torch.abs(
+            model_output * mel_masks - normalized_gt_mels * mel_masks
         ).sum() / (mel_masks.sum() * gt_mels.shape[1])
 
         self.log(
-            "train/loss",
-            loss,
+            "train/noise_loss",
+            noise_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            "train/vq_loss",
+            vq_loss,
             on_step=True,
             on_epoch=False,
             prog_bar=True,
@@ -141,7 +156,7 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        return loss
+        return noise_loss + vq_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -169,7 +184,7 @@ class VQDiffusion(L.LightningModule):
 
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)
-        self.noise_scheduler_infer.set_timesteps(20)
+        self.noise_scheduler_infer.set_timesteps(100)
 
         for t in self.noise_scheduler_infer.timesteps:
             timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

+ 186 - 0
fish_speech/models/vq_diffusion/wavenet.py

@@ -0,0 +1,186 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+    """Mish activation: x * tanh(softplus(x))."""
+
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+    """Diffusion Step Embedding.
+
+    Fixed (non-learnable) sinusoidal embedding of the scalar diffusion
+    timestep, following the Transformer positional-encoding scheme.
+    Maps a 1-D tensor of step indices to a (batch, d_denoiser) tensor.
+    """
+
+    def __init__(self, d_denoiser):
+        super(DiffusionEmbedding, self).__init__()
+        # Output embedding dimension; assumed even (split into sin/cos halves).
+        self.dim = d_denoiser
+
+    def forward(self, x):
+        # x: 1-D tensor of diffusion step indices, shape (batch,).
+        device = x.device
+        half_dim = self.dim // 2
+        # Geometric progression of frequencies spanning 1 .. 1/10000.
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        # Outer product of steps and frequencies -> (batch, half_dim).
+        emb = x[:, None] * emb[None, :]
+        # Concatenate sin and cos halves -> (batch, dim).
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class LinearNorm(nn.Module):
+    """LinearNorm Projection.
+
+    Thin wrapper over nn.Linear with Xavier-uniform weight initialization
+    and, when bias is enabled, zero-initialized bias.
+    """
+
+    def __init__(self, in_features, out_features, bias=False):
+        super(LinearNorm, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+        nn.init.xavier_uniform_(self.linear.weight)
+        if bias:
+            nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+class ConvNorm(nn.Module):
+    """1D Convolution.
+
+    nn.Conv1d wrapper with Kaiming-normal weight initialization. When no
+    padding is given, a "same"-length padding is computed for the odd
+    kernel size so the time dimension is preserved at stride 1.
+
+    NOTE(review): w_init_gain is accepted but never used — presumably a
+    leftover from a xavier-based init; confirm before relying on it.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=None,
+        dilation=1,
+        bias=True,
+        w_init_gain="linear",
+    ):
+        super(ConvNorm, self).__init__()
+
+        if padding is None:
+            # "same" padding requires an odd kernel size.
+            assert kernel_size % 2 == 1
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, signal):
+        # signal: (batch, in_channels, time) -> (batch, out_channels, time').
+        conv_signal = self.conv(signal)
+
+        return conv_signal
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block.
+
+    WaveNet-style gated residual block: a dilated conv on the (timestep-
+    shifted) input plus a 1x1-projected conditioner feeds a tanh/sigmoid
+    gate; the gated output is split into a residual path and a skip path.
+    """
+
+    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
+        super(ResidualBlock, self).__init__()
+        # Dilated conv doubles channels so the output can be split gate/filter.
+        # padding=dilation keeps length for kernel_size=3 ("same" padding).
+        self.conv_layer = ConvNorm(
+            residual_channels,
+            2 * residual_channels,
+            kernel_size=3,
+            stride=1,
+            padding=dilation,
+            dilation=dilation,
+        )
+        # Projects the diffusion-step embedding into the residual channels.
+        self.diffusion_projection = LinearNorm(
+            residual_channels, residual_channels, use_linear_bias
+        )
+        # 1x1 conv mapping the conditioner to gate/filter channels.
+        self.conditioner_projection = ConvNorm(
+            d_encoder, 2 * residual_channels, kernel_size=1
+        )
+        # 1x1 conv producing the concatenated residual + skip channels.
+        self.output_projection = ConvNorm(
+            residual_channels, 2 * residual_channels, kernel_size=1
+        )
+
+    def forward(self, x, conditioner, diffusion_step):
+        # diffusion_step: (batch, residual_channels) -> broadcast over time.
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        conditioner = self.conditioner_projection(conditioner)
+
+        y = x + diffusion_step
+
+        y = self.conv_layer(y) + conditioner
+
+        # Gated activation unit: sigmoid gate * tanh filter.
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+
+        # 1/sqrt(2) keeps the variance of the residual sum roughly constant.
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+    """
+    WaveNet
+    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
+
+    Stack of gated residual blocks used as a diffusion denoiser: takes a
+    noisy feature map, a diffusion timestep, a mask, and a conditioning
+    sequence, and returns a denoised feature map of the same time length.
+    """
+
+    def __init__(
+        self,
+        in_channels=128,
+        out_channels=128,
+        d_encoder=256,
+        residual_channels=512,
+        residual_layers=20,
+        use_linear_bias=False,
+        dilation_cycle=None,
+    ):
+        super(WaveNet, self).__init__()
+
+        self.input_projection = ConvNorm(in_channels, residual_channels, kernel_size=1)
+        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+        # MLP refining the sinusoidal step embedding (Mish nonlinearity).
+        self.mlp = nn.Sequential(
+            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
+            Mish(),
+            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
+        )
+        # Dilations cycle as 1, 2, 4, ... 2^(dilation_cycle-1); all 1 if
+        # dilation_cycle is falsy.
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    d_encoder,
+                    residual_channels,
+                    use_linear_bias=use_linear_bias,
+                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                )
+                for i in range(residual_layers)
+            ]
+        )
+        self.skip_projection = ConvNorm(
+            residual_channels, residual_channels, kernel_size=1
+        )
+        self.output_projection = ConvNorm(
+            residual_channels, out_channels, kernel_size=1
+        )
+        # Zero-init the final conv weight so early training outputs are near
+        # zero (bias is left at ConvNorm's default, so not exactly zero).
+        nn.init.zeros_(self.output_projection.conv.weight)
+
+    def forward(self, x, diffusion_step, x_masks, condition):
+        # x: (B, in_channels, T) noisy input.
+        # diffusion_step: 1-D tensor of step indices.
+        # x_masks: multiplicative mask, broadcastable to (B, C, T).
+        # condition: (B, d_encoder, T) conditioning features.
+        x = self.input_projection(x)  # x [B, residual_channel, T]
+        x = F.relu(x)
+
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        diffusion_step = self.mlp(diffusion_step)
+
+        # Collect the skip output of every residual block.
+        skip = []
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x * x_masks, condition * x_masks, diffusion_step)
+            skip.append(skip_connection)
+
+        # Average skips with 1/sqrt(L) scaling to keep variance stable.
+        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)  # [B, 128, T]
+
+        # Re-apply the mask so padded frames stay zero.
+        x = x * x_masks
+
+        return x