zhaohaipeng пре 1 месец
родитељ
комит
b4dfbee626
2 измењених фајлова са 176 додато и 113 уклоњено
  1. 140 111
      fish_speech/models/dac/modded_dac.py
  2. 36 2
      fish_speech/models/text2semantic/llama.py

+ 140 - 111
fish_speech/models/dac/modded_dac.py

@@ -3,9 +3,7 @@ import typing as tp
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
-import numpy as np
 import torch
-from audiotools import AudioSignal
 from audiotools.ml import BaseModel
 from dac.model.base import CodecMixin
 from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
@@ -64,7 +62,7 @@ class ModelArgs:
 
 class KVCache(nn.Module):
     def __init__(
-        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
+            self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
     ):
         super().__init__()
         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
@@ -143,14 +141,14 @@ class Transformer(nn.Module):
         self.use_kv_cache = True
 
     def forward(
-        self,
-        x: Tensor,
-        input_pos: Optional[Tensor] = None,
-        mask: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            input_pos: Optional[Tensor] = None,
+            mask: Optional[Tensor] = None,
     ) -> Tensor:
         if self.config.pos_embed_type == "rope":
             assert (
-                self.freqs_cis is not None
+                    self.freqs_cis is not None
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             # print("MAX", input_pos.max())
             freqs_cis = self.freqs_cis[input_pos]
@@ -182,11 +180,11 @@ class TransformerBlock(nn.Module):
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
 
     def forward(
-        self,
-        x: Tensor,
-        input_pos: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
+            self,
+            x: Tensor,
+            input_pos: Tensor,
+            freqs_cis: Tensor,
+            mask: Tensor,
     ) -> Tensor:
         h = x + self.attention_layer_scale(
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
@@ -241,11 +239,11 @@ class Attention(nn.Module):
         return rel_logits
 
     def forward(
-        self,
-        x: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            freqs_cis: Tensor,
+            mask: Tensor,
+            input_pos: Optional[Tensor] = None,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
@@ -259,15 +257,46 @@ class Attention(nn.Module):
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
 
-        if self.pos_embed_type == "rope":
-            q = apply_rotary_emb(q, freqs_cis)
-            k = apply_rotary_emb(k, freqs_cis)
+        # if self.pos_embed_type == "rope":
+        #     q = apply_rotary_emb(q, freqs_cis)
+        #     k = apply_rotary_emb(k, freqs_cis)
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
+            if input_pos is not None:
+                # =========================
+                # KV-cache window cropping (core optimization): keep only the most recent max_context positions
+                # =========================
+                max_context = 4096  # ⭐ 推荐 4K 或 8K
+
+                # Current valid length: highest position index in input_pos, plus one
+                seq_len = int(input_pos.max().item()) + 1
+
+                # Start index of the window
+                start = max(0, seq_len - max_context)
+
+                # Crop K and V to the window
+                k = k[:, :, start:seq_len, :]
+                v = v[:, :, start:seq_len, :]
+
+                # =========================
+                # Crop the mask to the same window (if a mask was given)
+                # =========================
+                if mask is not None:
+                    mask = mask[:, :, :, start:seq_len]
+
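+                # =========================
+                # Crop the RoPE frequencies to the same window (critical, otherwise it breaks). NOTE(review): freqs_cis appears to be already indexed by input_pos upstream; confirm this slice is valid
+                # =========================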
+                freqs_cis = freqs_cis[start:seq_len]
+
+        if self.pos_embed_type == "rope":
+            q = apply_rotary_emb(q, freqs_cis)
+            k = apply_rotary_emb(k, freqs_cis)
+
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
@@ -335,10 +364,10 @@ class RMSNorm(nn.Module):
 
 class LayerScale(nn.Module):
     def __init__(
-        self,
-        dim: int,
-        init_values: Union[float, Tensor] = 1e-2,
-        inplace: bool = False,
+            self,
+            dim: int,
+            init_values: Union[float, Tensor] = 1e-2,
+            inplace: bool = False,
     ) -> None:
         super().__init__()
         self.inplace = inplace
@@ -354,12 +383,12 @@ class WindowLimitedTransformer(Transformer):
     """
 
     def __init__(
-        self,
-        config: ModelArgs,
-        input_dim: int = 512,
-        window_size: Optional[int] = None,
-        causal: bool = True,
-        look_ahead_conv: nn.Module = None,
+            self,
+            config: ModelArgs,
+            input_dim: int = 512,
+            window_size: Optional[int] = None,
+            causal: bool = True,
+            look_ahead_conv: nn.Module = None,
     ):
         super().__init__(config)
         self.window_size = window_size
@@ -380,9 +409,9 @@ class WindowLimitedTransformer(Transformer):
         )
 
     def make_window_limited_mask(
-        self,
-        max_length: int,
-        x_lens: Optional[Tensor] = None,
+            self,
+            max_length: int,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Make mask to form window limited attention.
@@ -400,9 +429,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
 
     def make_mask(
-        self,
-        max_length: int,
-        x_lens: Optional[Tensor] = None,
+            self,
+            max_length: int,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Make ordinary mask if window size is not specified.
@@ -418,9 +447,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
 
     def forward(
-        self,
-        x: Tensor,
-        x_lens: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         if self.channels_first:
             x = x.transpose(1, 2)
@@ -442,10 +471,10 @@ class WindowLimitedTransformer(Transformer):
 
 
 def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+        seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
     freqs = 1.0 / (
-        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+            base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
@@ -485,7 +514,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
 
 
 def get_extra_padding_for_conv1d(
-    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+        x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
 ) -> int:
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
@@ -495,10 +524,10 @@ def get_extra_padding_for_conv1d(
 
 
 def pad1d(
-    x: torch.Tensor,
-    paddings: tp.Tuple[int, int],
-    mode: str = "zeros",
-    value: float = 0.0,
+        x: torch.Tensor,
+        paddings: tp.Tuple[int, int],
+        mode: str = "zeros",
+        value: float = 0.0,
 ):
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     If this is the case, we insert extra 0 padding to the right
@@ -522,14 +551,14 @@ def pad1d(
 
 class CausalConvNet(nn.Module):
     def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        dilation=1,
-        stride=1,
-        groups=1,
-        padding=None,
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            dilation=1,
+            stride=1,
+            groups=1,
+            padding=None,
     ):
         super(CausalConvNet, self).__init__()
         self.conv = nn.Conv1d(
@@ -564,7 +593,7 @@ class CausalConvNet(nn.Module):
 
 class CausalTransConvNet(nn.Module):
     def __init__(
-        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+            self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
     ):
         super(CausalTransConvNet, self).__init__()
         self.conv = nn.ConvTranspose1d(
@@ -618,18 +647,18 @@ class ResidualUnit(nn.Module):
             if self.causal:
                 x = x[..., :-pad]
             else:
-                x = x[..., pad // 2 : -pad // 2]
+                x = x[..., pad // 2: -pad // 2]
         return x + y
 
 
 class EncoderBlock(nn.Module):
     def __init__(
-        self,
-        dim: int = 16,
-        stride: int = 1,
-        causal: bool = False,
-        n_t_layer: int = 0,
-        transformer_general_config=None,
+            self,
+            dim: int = 16,
+            stride: int = 1,
+            causal: bool = False,
+            n_t_layer: int = 0,
+            transformer_general_config=None,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -671,13 +700,13 @@ class EncoderBlock(nn.Module):
 
 class Encoder(nn.Module):
     def __init__(
-        self,
-        d_model: int = 64,
-        strides: list = [2, 4, 8, 8],
-        d_latent: int = 64,
-        n_transformer_layers: list = [0, 0, 4, 4],
-        transformer_general_config: ModelArgs = None,
-        causal: bool = False,
+            self,
+            d_model: int = 64,
+            strides: list = [2, 4, 8, 8],
+            d_latent: int = 64,
+            n_transformer_layers: list = [0, 0, 4, 4],
+            transformer_general_config: ModelArgs = None,
+            causal: bool = False,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -713,13 +742,13 @@ class Encoder(nn.Module):
 
 class DecoderBlock(nn.Module):
     def __init__(
-        self,
-        input_dim: int = 16,
-        output_dim: int = 8,
-        stride: int = 1,
-        causal: bool = False,
-        n_t_layer: int = 0,
-        transformer_general_config=None,
+            self,
+            input_dim: int = 16,
+            output_dim: int = 8,
+            stride: int = 1,
+            causal: bool = False,
+            n_t_layer: int = 0,
+            transformer_general_config=None,
     ):
         super().__init__()
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
@@ -761,14 +790,14 @@ class DecoderBlock(nn.Module):
 
 class Decoder(nn.Module):
     def __init__(
-        self,
-        input_channel,
-        channels,
-        rates,
-        d_out: int = 1,
-        causal: bool = False,
-        n_transformer_layers: list = [0, 0, 0, 0],
-        transformer_general_config=None,
+            self,
+            input_channel,
+            channels,
+            rates,
+            d_out: int = 1,
+            causal: bool = False,
+            n_transformer_layers: list = [0, 0, 0, 0],
+            transformer_general_config=None,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -777,7 +806,7 @@ class Decoder(nn.Module):
 
         # Add upsampling + MRF blocks
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
-            input_dim = channels // 2**i
+            input_dim = channels // 2 ** i
             output_dim = channels // 2 ** (i + 1)
             layers += [
                 DecoderBlock(
@@ -805,19 +834,19 @@ class Decoder(nn.Module):
 
 class DAC(BaseModel, CodecMixin):
     def __init__(
-        self,
-        encoder_dim: int = 64,
-        encoder_rates: List[int] = [2, 4, 8, 8],
-        latent_dim: int = None,
-        decoder_dim: int = 1536,
-        decoder_rates: List[int] = [8, 8, 4, 2],
-        quantizer: torch.nn.Module = None,
-        sample_rate: int = 44100,
-        causal: bool = True,
-        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
-        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
-        overwrite_decoder: torch.nn.Module = None,
-        transformer_general_config=None,
+            self,
+            encoder_dim: int = 64,
+            encoder_rates: List[int] = [2, 4, 8, 8],
+            latent_dim: int = None,
+            decoder_dim: int = 1536,
+            decoder_rates: List[int] = [8, 8, 4, 2],
+            quantizer: torch.nn.Module = None,
+            sample_rate: int = 44100,
+            causal: bool = True,
+            encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+            decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+            overwrite_decoder: torch.nn.Module = None,
+            transformer_general_config=None,
     ):
         super().__init__()
 
@@ -874,11 +903,11 @@ class DAC(BaseModel, CodecMixin):
         return audio_data
 
     def encode(
-        self,
-        audio_data: torch.Tensor,
-        audio_lengths: torch.Tensor = None,
-        n_quantizers: int = None,
-        **kwargs,
+            self,
+            audio_data: torch.Tensor,
+            audio_lengths: torch.Tensor = None,
+            n_quantizers: int = None,
+            **kwargs,
     ):
         """Encode given audio data and return quantized latent codes
 
@@ -948,13 +977,13 @@ class DAC(BaseModel, CodecMixin):
         return self.decoder(z)
 
     def forward(
-        self,
-        audio_data: torch.Tensor,
-        template: torch.Tensor = None,
-        mask: torch.Tensor = None,
-        sample_rate: int = None,
-        n_quantizers: int = None,
-        **kwargs,
+            self,
+            audio_data: torch.Tensor,
+            template: torch.Tensor = None,
+            mask: torch.Tensor = None,
+            sample_rate: int = None,
+            n_quantizers: int = None,
+            **kwargs,
     ):
         """Model forward pass
 

+ 36 - 2
fish_speech/models/text2semantic/llama.py

@@ -315,6 +315,8 @@ class BaseTransformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
+        logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
+        logger.info(f"self.layers: {len(self.layers)}")
         for b in self.layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
@@ -711,6 +713,8 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
+        logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
+        logger.info(f"self.fast_layers: {len(self.fast_layers)}")
         # Fast transformer
         # The max seq len here is the number of codebooks
         for b in self.fast_layers:
@@ -903,14 +907,44 @@ class Attention(nn.Module):
             q = self.q_norm(q)
             k = self.k_norm(k)
 
-        q = apply_rotary_emb(q, freqs_cis)
-        k = apply_rotary_emb(k, freqs_cis)
+        # q = apply_rotary_emb(q, freqs_cis)
+        # k = apply_rotary_emb(k, freqs_cis)
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
+            if input_pos is not None:
+                # =========================
+                # KV-cache window cropping (core optimization): keep only the most recent max_context positions
+                # =========================
+                max_context = 4096  # ⭐ 推荐 4K 或 8K
+
+                # Current valid length: highest position index in input_pos, plus one
+                seq_len = int(input_pos.max().item()) + 1
+
+                # Start index of the window
+                start = max(0, seq_len - max_context)
+
+                # Crop K and V to the window
+                k = k[:, :, start:seq_len, :]
+                v = v[:, :, start:seq_len, :]
+
+                # =========================
+                # Crop the mask to the same window (if a mask was given)
+                # =========================
+                if mask is not None:
+                    mask = mask[:, :, :, start:seq_len]
+
+                # =========================
+                # Crop the RoPE frequencies to the same window (critical, otherwise it breaks)
+                # =========================
+                freqs_cis = freqs_cis[start:seq_len]
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)