zhaohaipeng 1 месяц назад
Родитель
Commit
b4dfbee626
2 изменённых файла: 176 добавлений и 113 удалений
  1. 140 111
      fish_speech/models/dac/modded_dac.py
  2. 36 2
      fish_speech/models/text2semantic/llama.py

+ 140 - 111
fish_speech/models/dac/modded_dac.py

@@ -3,9 +3,7 @@ import typing as tp
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import List, Optional, Union
 from typing import List, Optional, Union
 
 
-import numpy as np
 import torch
 import torch
-from audiotools import AudioSignal
 from audiotools.ml import BaseModel
 from audiotools.ml import BaseModel
 from dac.model.base import CodecMixin
 from dac.model.base import CodecMixin
 from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
 from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
@@ -64,7 +62,7 @@ class ModelArgs:
 
 
 class KVCache(nn.Module):
 class KVCache(nn.Module):
     def __init__(
     def __init__(
-        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
+            self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
     ):
     ):
         super().__init__()
         super().__init__()
         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
@@ -143,14 +141,14 @@ class Transformer(nn.Module):
         self.use_kv_cache = True
         self.use_kv_cache = True
 
 
     def forward(
     def forward(
-        self,
-        x: Tensor,
-        input_pos: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            input_pos: Optional[Tensor] = None,
+            mask: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         if self.config.pos_embed_type == "rope":
         if self.config.pos_embed_type == "rope":
             assert (
             assert (
-                self.freqs_cis is not None
+                    self.freqs_cis is not None
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             # print("MAX", input_pos.max())
             # print("MAX", input_pos.max())
             freqs_cis = self.freqs_cis[input_pos]
             freqs_cis = self.freqs_cis[input_pos]
@@ -182,11 +180,11 @@ class TransformerBlock(nn.Module):
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
 
 
     def forward(
     def forward(
-        self,
-        x: Tensor,
-        input_pos: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
+            self,
+            x: Tensor,
+            input_pos: Tensor,
+            freqs_cis: Tensor,
+            mask: Tensor,
     ) -> Tensor:
     ) -> Tensor:
         h = x + self.attention_layer_scale(
         h = x + self.attention_layer_scale(
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
@@ -241,11 +239,11 @@ class Attention(nn.Module):
         return rel_logits
         return rel_logits
 
 
     def forward(
     def forward(
-        self,
-        x: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            freqs_cis: Tensor,
+            mask: Tensor,
+            input_pos: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
         bsz, seqlen, _ = x.shape
 
 
@@ -259,15 +257,46 @@ class Attention(nn.Module):
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
 
 
-        if self.pos_embed_type == "rope":
-            q = apply_rotary_emb(q, freqs_cis)
-            k = apply_rotary_emb(k, freqs_cis)
+        # if self.pos_embed_type == "rope":
+        #     q = apply_rotary_emb(q, freqs_cis)
+        #     k = apply_rotary_emb(k, freqs_cis)
 
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
 
         if self.kv_cache is not None:
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
             k, v = self.kv_cache.update(input_pos, k, v)
 
 
+            if input_pos is not None:
+                # =========================
+                # Crop the KV cache to a sliding window (core optimization)
+                # =========================
+                max_context = 4096  # recommended: 4K or 8K
+
+                # Current valid length (highest position index + 1)
+                seq_len = int(input_pos.max().item()) + 1
+
+                # Start index of the window
+                start = max(0, seq_len - max_context)
+
+                # Crop K and V to the window
+                k = k[:, :, start:seq_len, :]
+                v = v[:, :, start:seq_len, :]
+
+                # =========================
+                # Crop the mask to the same window (if one is given)
+                # =========================
+                if mask is not None:
+                    mask = mask[:, :, :, start:seq_len]
+
+                # =========================
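+                # Crop the RoPE frequencies to the same window (author: critical, breaks otherwise). NOTE(review): freqs_cis looks already gathered by input_pos upstream - verify this slice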
+                # =========================
+                freqs_cis = freqs_cis[start:seq_len]
+
+        if self.pos_embed_type == "rope":
+            q = apply_rotary_emb(q, freqs_cis)
+            k = apply_rotary_emb(k, freqs_cis)
+
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
 
@@ -335,10 +364,10 @@ class RMSNorm(nn.Module):
 
 
 class LayerScale(nn.Module):
 class LayerScale(nn.Module):
     def __init__(
     def __init__(
-        self,
-        dim: int,
-        init_values: Union[float, Tensor] = 1e-2,
-        inplace: bool = False,
+            self,
+            dim: int,
+            init_values: Union[float, Tensor] = 1e-2,
+            inplace: bool = False,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.inplace = inplace
         self.inplace = inplace
@@ -354,12 +383,12 @@ class WindowLimitedTransformer(Transformer):
     """
     """
 
 
     def __init__(
     def __init__(
-        self,
-        config: ModelArgs,
-        input_dim: int = 512,
-        window_size: Optional[int] = None,
-        causal: bool = True,
-        look_ahead_conv: nn.Module = None,
+            self,
+            config: ModelArgs,
+            input_dim: int = 512,
+            window_size: Optional[int] = None,
+            causal: bool = True,
+            look_ahead_conv: nn.Module = None,
     ):
     ):
         super().__init__(config)
         super().__init__(config)
         self.window_size = window_size
         self.window_size = window_size
@@ -380,9 +409,9 @@ class WindowLimitedTransformer(Transformer):
         )
         )
 
 
     def make_window_limited_mask(
     def make_window_limited_mask(
-        self,
-        max_length: int,
-        x_lens: Optional[Tensor] = None,
+            self,
+            max_length: int,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         """
         """
         Make mask to form window limited attention.
         Make mask to form window limited attention.
@@ -400,9 +429,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
         return mask
 
 
     def make_mask(
     def make_mask(
-        self,
-        max_length: int,
-        x_lens: Optional[Tensor] = None,
+            self,
+            max_length: int,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         """
         """
         Make ordinary mask if window size is not specified.
         Make ordinary mask if window size is not specified.
@@ -418,9 +447,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
         return mask
 
 
     def forward(
     def forward(
-        self,
-        x: Tensor,
-        x_lens: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         if self.channels_first:
         if self.channels_first:
             x = x.transpose(1, 2)
             x = x.transpose(1, 2)
@@ -442,10 +471,10 @@ class WindowLimitedTransformer(Transformer):
 
 
 
 
 def precompute_freqs_cis(
 def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+        seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
 ) -> Tensor:
     freqs = 1.0 / (
     freqs = 1.0 / (
-        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+            base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     )
     t = torch.arange(seq_len, device=freqs.device)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs = torch.outer(t, freqs)
@@ -485,7 +514,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
 
 
 
 
 def get_extra_padding_for_conv1d(
 def get_extra_padding_for_conv1d(
-    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+        x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
 ) -> int:
 ) -> int:
     """See `pad_for_conv1d`."""
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
     length = x.shape[-1]
@@ -495,10 +524,10 @@ def get_extra_padding_for_conv1d(
 
 
 
 
 def pad1d(
 def pad1d(
-    x: torch.Tensor,
-    paddings: tp.Tuple[int, int],
-    mode: str = "zeros",
-    value: float = 0.0,
+        x: torch.Tensor,
+        paddings: tp.Tuple[int, int],
+        mode: str = "zeros",
+        value: float = 0.0,
 ):
 ):
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     If this is the case, we insert extra 0 padding to the right
     If this is the case, we insert extra 0 padding to the right
@@ -522,14 +551,14 @@ def pad1d(
 
 
 class CausalConvNet(nn.Module):
 class CausalConvNet(nn.Module):
     def __init__(
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        dilation=1,
-        stride=1,
-        groups=1,
-        padding=None,
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            dilation=1,
+            stride=1,
+            groups=1,
+            padding=None,
     ):
     ):
         super(CausalConvNet, self).__init__()
         super(CausalConvNet, self).__init__()
         self.conv = nn.Conv1d(
         self.conv = nn.Conv1d(
@@ -564,7 +593,7 @@ class CausalConvNet(nn.Module):
 
 
 class CausalTransConvNet(nn.Module):
 class CausalTransConvNet(nn.Module):
     def __init__(
     def __init__(
-        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+            self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
     ):
     ):
         super(CausalTransConvNet, self).__init__()
         super(CausalTransConvNet, self).__init__()
         self.conv = nn.ConvTranspose1d(
         self.conv = nn.ConvTranspose1d(
@@ -618,18 +647,18 @@ class ResidualUnit(nn.Module):
             if self.causal:
             if self.causal:
                 x = x[..., :-pad]
                 x = x[..., :-pad]
             else:
             else:
-                x = x[..., pad // 2 : -pad // 2]
+                x = x[..., pad // 2: -pad // 2]
         return x + y
         return x + y
 
 
 
 
 class EncoderBlock(nn.Module):
 class EncoderBlock(nn.Module):
     def __init__(
     def __init__(
-        self,
-        dim: int = 16,
-        stride: int = 1,
-        causal: bool = False,
-        n_t_layer: int = 0,
-        transformer_general_config=None,
+            self,
+            dim: int = 16,
+            stride: int = 1,
+            causal: bool = False,
+            n_t_layer: int = 0,
+            transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -671,13 +700,13 @@ class EncoderBlock(nn.Module):
 
 
 class Encoder(nn.Module):
 class Encoder(nn.Module):
     def __init__(
     def __init__(
-        self,
-        d_model: int = 64,
-        strides: list = [2, 4, 8, 8],
-        d_latent: int = 64,
-        n_transformer_layers: list = [0, 0, 4, 4],
-        transformer_general_config: ModelArgs = None,
-        causal: bool = False,
+            self,
+            d_model: int = 64,
+            strides: list = [2, 4, 8, 8],
+            d_latent: int = 64,
+            n_transformer_layers: list = [0, 0, 4, 4],
+            transformer_general_config: ModelArgs = None,
+            causal: bool = False,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -713,13 +742,13 @@ class Encoder(nn.Module):
 
 
 class DecoderBlock(nn.Module):
 class DecoderBlock(nn.Module):
     def __init__(
     def __init__(
-        self,
-        input_dim: int = 16,
-        output_dim: int = 8,
-        stride: int = 1,
-        causal: bool = False,
-        n_t_layer: int = 0,
-        transformer_general_config=None,
+            self,
+            input_dim: int = 16,
+            output_dim: int = 8,
+            stride: int = 1,
+            causal: bool = False,
+            n_t_layer: int = 0,
+            transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
@@ -761,14 +790,14 @@ class DecoderBlock(nn.Module):
 
 
 class Decoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(
     def __init__(
-        self,
-        input_channel,
-        channels,
-        rates,
-        d_out: int = 1,
-        causal: bool = False,
-        n_transformer_layers: list = [0, 0, 0, 0],
-        transformer_general_config=None,
+            self,
+            input_channel,
+            channels,
+            rates,
+            d_out: int = 1,
+            causal: bool = False,
+            n_transformer_layers: list = [0, 0, 0, 0],
+            transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -777,7 +806,7 @@ class Decoder(nn.Module):
 
 
         # Add upsampling + MRF blocks
         # Add upsampling + MRF blocks
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
-            input_dim = channels // 2**i
+            input_dim = channels // 2 ** i
             output_dim = channels // 2 ** (i + 1)
             output_dim = channels // 2 ** (i + 1)
             layers += [
             layers += [
                 DecoderBlock(
                 DecoderBlock(
@@ -805,19 +834,19 @@ class Decoder(nn.Module):
 
 
 class DAC(BaseModel, CodecMixin):
 class DAC(BaseModel, CodecMixin):
     def __init__(
     def __init__(
-        self,
-        encoder_dim: int = 64,
-        encoder_rates: List[int] = [2, 4, 8, 8],
-        latent_dim: int = None,
-        decoder_dim: int = 1536,
-        decoder_rates: List[int] = [8, 8, 4, 2],
-        quantizer: torch.nn.Module = None,
-        sample_rate: int = 44100,
-        causal: bool = True,
-        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
-        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
-        overwrite_decoder: torch.nn.Module = None,
-        transformer_general_config=None,
+            self,
+            encoder_dim: int = 64,
+            encoder_rates: List[int] = [2, 4, 8, 8],
+            latent_dim: int = None,
+            decoder_dim: int = 1536,
+            decoder_rates: List[int] = [8, 8, 4, 2],
+            quantizer: torch.nn.Module = None,
+            sample_rate: int = 44100,
+            causal: bool = True,
+            encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+            decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+            overwrite_decoder: torch.nn.Module = None,
+            transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -874,11 +903,11 @@ class DAC(BaseModel, CodecMixin):
         return audio_data
         return audio_data
 
 
     def encode(
     def encode(
-        self,
-        audio_data: torch.Tensor,
-        audio_lengths: torch.Tensor = None,
-        n_quantizers: int = None,
-        **kwargs,
+            self,
+            audio_data: torch.Tensor,
+            audio_lengths: torch.Tensor = None,
+            n_quantizers: int = None,
+            **kwargs,
     ):
     ):
         """Encode given audio data and return quantized latent codes
         """Encode given audio data and return quantized latent codes
 
 
@@ -948,13 +977,13 @@ class DAC(BaseModel, CodecMixin):
         return self.decoder(z)
         return self.decoder(z)
 
 
     def forward(
     def forward(
-        self,
-        audio_data: torch.Tensor,
-        template: torch.Tensor = None,
-        mask: torch.Tensor = None,
-        sample_rate: int = None,
-        n_quantizers: int = None,
-        **kwargs,
+            self,
+            audio_data: torch.Tensor,
+            template: torch.Tensor = None,
+            mask: torch.Tensor = None,
+            sample_rate: int = None,
+            n_quantizers: int = None,
+            **kwargs,
     ):
     ):
         """Model forward pass
         """Model forward pass
 
 

+ 36 - 2
fish_speech/models/text2semantic/llama.py

@@ -315,6 +315,8 @@ class BaseTransformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.max_batch_size = max_batch_size
 
 
+        logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
+        logger.info(f"self.layers: {len(self.layers)}")
         for b in self.layers:
         for b in self.layers:
             b.attention.kv_cache = KVCache(
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 max_batch_size,
@@ -711,6 +713,8 @@ class DualARTransformer(BaseTransformer):
     ):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
 
+        logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
+        logger.info(f"self.fast_layers: {len(self.fast_layers)}")
         # Fast transformer
         # Fast transformer
         # The max seq len here is the number of codebooks
         # The max seq len here is the number of codebooks
         for b in self.fast_layers:
         for b in self.fast_layers:
@@ -903,14 +907,44 @@ class Attention(nn.Module):
             q = self.q_norm(q)
             q = self.q_norm(q)
             k = self.k_norm(k)
             k = self.k_norm(k)
 
 
-        q = apply_rotary_emb(q, freqs_cis)
-        k = apply_rotary_emb(k, freqs_cis)
+        # q = apply_rotary_emb(q, freqs_cis)
+        # k = apply_rotary_emb(k, freqs_cis)
 
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
 
         if self.kv_cache is not None:
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
             k, v = self.kv_cache.update(input_pos, k, v)
 
 
+            if input_pos is not None:
+                # =========================
+                # Crop the KV cache to a sliding window (core optimization)
+                # =========================
+                max_context = 4096  # recommended: 4K or 8K
+
+                # Current valid length (highest position index + 1)
+                seq_len = int(input_pos.max().item()) + 1
+
+                # Start index of the window
+                start = max(0, seq_len - max_context)
+
+                # Crop K and V to the window
+                k = k[:, :, start:seq_len, :]
+                v = v[:, :, start:seq_len, :]
+
+                # =========================
+                # Crop the mask to the same window (if one is given)
+                # =========================
+                if mask is not None:
+                    mask = mask[:, :, :, start:seq_len]
+
+                # =========================
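+                # Crop the RoPE frequencies to the same window (author: critical, breaks otherwise). NOTE(review): confirm freqs_cis is indexed by absolute position here before slicing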
+                # =========================
+                freqs_cis = freqs_cis[start:seq_len]
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)