Explorar el Código

feat:修改裁减freqs的逻辑

zhaohaipeng hace 1 mes
padre
commit
e6f7231b19
Se han modificado 2 ficheros con 128 adiciones y 230 borrados
  1. 108 143
      fish_speech/models/dac/modded_dac.py
  2. 20 87
      fish_speech/models/text2semantic/llama.py

+ 108 - 143
fish_speech/models/dac/modded_dac.py

@@ -143,14 +143,14 @@ class Transformer(nn.Module):
         self.use_kv_cache = True
 
     def forward(
-            self,
-            x: Tensor,
-            input_pos: Optional[Tensor] = None,
-            mask: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         if self.config.pos_embed_type == "rope":
             assert (
-                    self.freqs_cis is not None
+                self.freqs_cis is not None
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             # print("MAX", input_pos.max())
             freqs_cis = self.freqs_cis[input_pos]
@@ -182,11 +182,11 @@ class TransformerBlock(nn.Module):
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
 
     def forward(
-            self,
-            x: Tensor,
-            input_pos: Tensor,
-            freqs_cis: Tensor,
-            mask: Tensor,
+        self,
+        x: Tensor,
+        input_pos: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
     ) -> Tensor:
         h = x + self.attention_layer_scale(
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
@@ -241,16 +241,14 @@ class Attention(nn.Module):
         return rel_logits
 
     def forward(
-            self,
-            x: Tensor,
-            freqs_cis: Tensor,
-            mask: Tensor,
-            input_pos: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
-        print(f"Attention forward self.n_local_heads {self.n_local_heads}, self.head_dim {self.head_dim}")
-
         kv_size = self.n_local_heads * self.head_dim
         q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
         context_seqlen = seqlen
@@ -259,48 +257,15 @@ class Attention(nn.Module):
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
 
-        # if self.pos_embed_type == "rope":
-        #     q = apply_rotary_emb(q, freqs_cis)
-        #     k = apply_rotary_emb(k, freqs_cis)
+        if self.pos_embed_type == "rope":
+            q = apply_rotary_emb(q, freqs_cis)
+            k = apply_rotary_emb(k, freqs_cis)
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-            if input_pos is not None:
-                # =========================
-                # 🔥 KV cache window 裁剪(核心优化)
-                # =========================
-                max_context = 4096  # ⭐ 推荐 4K 或 8K
-
-                # 当前有效长度
-                seq_len = int(input_pos.max().item()) + 1
-
-                # window 起点
-                start = max(0, seq_len - max_context)
-
-                # 裁剪 KV
-                k = k[:, :, start:seq_len, :]
-                v = v[:, :, start:seq_len, :]
-
-                # =========================
-                # 🔥 同步裁剪 mask(如果有)
-                # =========================
-                if mask is not None:
-                    mask = mask[:, :, :, start:seq_len]
-
-                # =========================
-                # 🔥 同步裁剪 RoPE(关键,不然会炸)
-                # =========================
-                print(f"input_pos.dtype {input_pos.dtype}")
-                assert input_pos.dtype == torch.long
-                freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
-
-        if self.pos_embed_type == "rope":
-            q = apply_rotary_emb(q, freqs_cis)
-            k = apply_rotary_emb(k, freqs_cis)
-
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
@@ -368,10 +333,10 @@ class RMSNorm(nn.Module):
 
 class LayerScale(nn.Module):
     def __init__(
-            self,
-            dim: int,
-            init_values: Union[float, Tensor] = 1e-2,
-            inplace: bool = False,
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-2,
+        inplace: bool = False,
     ) -> None:
         super().__init__()
         self.inplace = inplace
@@ -387,12 +352,12 @@ class WindowLimitedTransformer(Transformer):
     """
 
     def __init__(
-            self,
-            config: ModelArgs,
-            input_dim: int = 512,
-            window_size: Optional[int] = None,
-            causal: bool = True,
-            look_ahead_conv: nn.Module = None,
+        self,
+        config: ModelArgs,
+        input_dim: int = 512,
+        window_size: Optional[int] = None,
+        causal: bool = True,
+        look_ahead_conv: nn.Module = None,
     ):
         super().__init__(config)
         self.window_size = window_size
@@ -413,9 +378,9 @@ class WindowLimitedTransformer(Transformer):
         )
 
     def make_window_limited_mask(
-            self,
-            max_length: int,
-            x_lens: Optional[Tensor] = None,
+        self,
+        max_length: int,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Make mask to form window limited attention.
@@ -433,9 +398,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
 
     def make_mask(
-            self,
-            max_length: int,
-            x_lens: Optional[Tensor] = None,
+        self,
+        max_length: int,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Make ordinary mask if window size is not specified.
@@ -451,9 +416,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
 
     def forward(
-            self,
-            x: Tensor,
-            x_lens: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
         if self.channels_first:
             x = x.transpose(1, 2)
@@ -475,10 +440,10 @@ class WindowLimitedTransformer(Transformer):
 
 
 def precompute_freqs_cis(
-        seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
     freqs = 1.0 / (
-            base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
@@ -518,7 +483,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
 
 
 def get_extra_padding_for_conv1d(
-        x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
 ) -> int:
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
@@ -528,10 +493,10 @@ def get_extra_padding_for_conv1d(
 
 
 def pad1d(
-        x: torch.Tensor,
-        paddings: tp.Tuple[int, int],
-        mode: str = "zeros",
-        value: float = 0.0,
+    x: torch.Tensor,
+    paddings: tp.Tuple[int, int],
+    mode: str = "zeros",
+    value: float = 0.0,
 ):
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     If this is the case, we insert extra 0 padding to the right
@@ -555,14 +520,14 @@ def pad1d(
 
 class CausalConvNet(nn.Module):
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            dilation=1,
-            stride=1,
-            groups=1,
-            padding=None,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        dilation=1,
+        stride=1,
+        groups=1,
+        padding=None,
     ):
         super(CausalConvNet, self).__init__()
         self.conv = nn.Conv1d(
@@ -597,7 +562,7 @@ class CausalConvNet(nn.Module):
 
 class CausalTransConvNet(nn.Module):
     def __init__(
-            self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
     ):
         super(CausalTransConvNet, self).__init__()
         self.conv = nn.ConvTranspose1d(
@@ -651,18 +616,18 @@ class ResidualUnit(nn.Module):
             if self.causal:
                 x = x[..., :-pad]
             else:
-                x = x[..., pad // 2: -pad // 2]
+                x = x[..., pad // 2 : -pad // 2]
         return x + y
 
 
 class EncoderBlock(nn.Module):
     def __init__(
-            self,
-            dim: int = 16,
-            stride: int = 1,
-            causal: bool = False,
-            n_t_layer: int = 0,
-            transformer_general_config=None,
+        self,
+        dim: int = 16,
+        stride: int = 1,
+        causal: bool = False,
+        n_t_layer: int = 0,
+        transformer_general_config=None,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -704,13 +669,13 @@ class EncoderBlock(nn.Module):
 
 class Encoder(nn.Module):
     def __init__(
-            self,
-            d_model: int = 64,
-            strides: list = [2, 4, 8, 8],
-            d_latent: int = 64,
-            n_transformer_layers: list = [0, 0, 4, 4],
-            transformer_general_config: ModelArgs = None,
-            causal: bool = False,
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 64,
+        n_transformer_layers: list = [0, 0, 4, 4],
+        transformer_general_config: ModelArgs = None,
+        causal: bool = False,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -746,13 +711,13 @@ class Encoder(nn.Module):
 
 class DecoderBlock(nn.Module):
     def __init__(
-            self,
-            input_dim: int = 16,
-            output_dim: int = 8,
-            stride: int = 1,
-            causal: bool = False,
-            n_t_layer: int = 0,
-            transformer_general_config=None,
+        self,
+        input_dim: int = 16,
+        output_dim: int = 8,
+        stride: int = 1,
+        causal: bool = False,
+        n_t_layer: int = 0,
+        transformer_general_config=None,
     ):
         super().__init__()
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
@@ -794,14 +759,14 @@ class DecoderBlock(nn.Module):
 
 class Decoder(nn.Module):
     def __init__(
-            self,
-            input_channel,
-            channels,
-            rates,
-            d_out: int = 1,
-            causal: bool = False,
-            n_transformer_layers: list = [0, 0, 0, 0],
-            transformer_general_config=None,
+        self,
+        input_channel,
+        channels,
+        rates,
+        d_out: int = 1,
+        causal: bool = False,
+        n_transformer_layers: list = [0, 0, 0, 0],
+        transformer_general_config=None,
     ):
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -810,7 +775,7 @@ class Decoder(nn.Module):
 
         # Add upsampling + MRF blocks
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
-            input_dim = channels // 2 ** i
+            input_dim = channels // 2**i
             output_dim = channels // 2 ** (i + 1)
             layers += [
                 DecoderBlock(
@@ -838,19 +803,19 @@ class Decoder(nn.Module):
 
 class DAC(BaseModel, CodecMixin):
     def __init__(
-            self,
-            encoder_dim: int = 64,
-            encoder_rates: List[int] = [2, 4, 8, 8],
-            latent_dim: int = None,
-            decoder_dim: int = 1536,
-            decoder_rates: List[int] = [8, 8, 4, 2],
-            quantizer: torch.nn.Module = None,
-            sample_rate: int = 44100,
-            causal: bool = True,
-            encoder_transformer_layers: List[int] = [0, 0, 0, 0],
-            decoder_transformer_layers: List[int] = [0, 0, 0, 0],
-            overwrite_decoder: torch.nn.Module = None,
-            transformer_general_config=None,
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        quantizer: torch.nn.Module = None,
+        sample_rate: int = 44100,
+        causal: bool = True,
+        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+        overwrite_decoder: torch.nn.Module = None,
+        transformer_general_config=None,
     ):
         super().__init__()
 
@@ -907,11 +872,11 @@ class DAC(BaseModel, CodecMixin):
         return audio_data
 
     def encode(
-            self,
-            audio_data: torch.Tensor,
-            audio_lengths: torch.Tensor = None,
-            n_quantizers: int = None,
-            **kwargs,
+        self,
+        audio_data: torch.Tensor,
+        audio_lengths: torch.Tensor = None,
+        n_quantizers: int = None,
+        **kwargs,
     ):
         """Encode given audio data and return quantized latent codes
 
@@ -981,13 +946,13 @@ class DAC(BaseModel, CodecMixin):
         return self.decoder(z)
 
     def forward(
-            self,
-            audio_data: torch.Tensor,
-            template: torch.Tensor = None,
-            mask: torch.Tensor = None,
-            sample_rate: int = None,
-            n_quantizers: int = None,
-            **kwargs,
+        self,
+        audio_data: torch.Tensor,
+        template: torch.Tensor = None,
+        mask: torch.Tensor = None,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+        **kwargs,
     ):
         """Model forward pass
 

+ 20 - 87
fish_speech/models/text2semantic/llama.py

@@ -887,20 +887,16 @@ class Attention(nn.Module):
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
     def forward(
-            self,
-            x: torch.Tensor,
-            freqs_cis: torch.Tensor,
-            mask: Optional[torch.Tensor],
-            input_pos: Optional[torch.Tensor] = None,
-    ):
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         q_size = self.n_head * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
-
-        # =========================
-        # QKV projection
-        # =========================
         q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
@@ -911,87 +907,28 @@ class Attention(nn.Module):
             q = self.q_norm(q)
             k = self.k_norm(k)
 
-        # [B, H, T, D]
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
 
-        # =========================
-        # KV Cache + Sliding Window
-        # =========================
-        start = 0
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
         if self.kv_cache is not None:
-            # update cache
             k, v = self.kv_cache.update(input_pos, k, v)
 
-            max_context = 4096  # 可调
-
-            if input_pos is not None and seqlen == 1:
-                # decode
-                seq_len = int(input_pos.item()) + 1
-            else:
-                # prefill
-                seq_len = k.size(2)
-
-            start = max(0, seq_len - max_context)
-
-            # window 裁剪
-            k = k[:, :, start:seq_len, :]
-            v = v[:, :, start:seq_len, :]
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
-            if mask is not None:
-                mask = mask[:, :, :, start:seq_len]
-        else:
-            seq_len = seqlen
-
-        # =========================
-        # RoPE(核心:以 K 为基准)
-        # =========================
-        T_k = k.size(2)
-
-        if self.kv_cache is not None:
-            freqs = freqs_cis[start: start + T_k]
-        else:
-            freqs = freqs_cis[:T_k]
-
-        # =========================
-        # Apply RoPE(Q/K 分开对齐)
-        # =========================
-        T_q = q.size(2)
-
-        print("Q:", q.shape)
-        print("K:", k.shape)
-        print("freqs:", freqs.shape)
-        print("start:", start)
-
-        assert k.size(2) == freqs.size(0), f"K vs freqs mismatch"
-        assert q.size(2) <= k.size(2), f"Q longer than K??"
-
-        # Q 用最后 T_q 个位置
-        q = apply_rotary_emb(q, freqs[-T_q:])
-
-        # K 用完整 window
-        k = apply_rotary_emb(k, freqs)
-
-        # =========================
-        # GQA expand
-        # =========================
-        if self.n_head != self.n_local_heads:
-            repeat = self.n_head // self.n_local_heads
-            k = k.repeat_interleave(repeat, dim=1)
-            v = v.repeat_interleave(repeat, dim=1)
-
-        # =========================
-        # Attention
-        # =========================
         if self.use_sdpa:
             if mask is None:
-                y = torch.nn.functional.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    dropout_p=self.dropout if self.training else 0.0,
-                    is_causal=True,
-                )
+                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                    y = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=True,
+                        # No third party attn_mask here to use flash_attention
+                    )
             else:
                 y = torch.nn.functional.scaled_dot_product_attention(
                     q,
@@ -1009,14 +946,10 @@ class Attention(nn.Module):
                 dropout_p=self.dropout if self.training else 0.0,
             )
 
-        # =========================
-        # Output
-        # =========================
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
         return self.wo(y)
 
-
     def eq_scaled_dot_product_attention(
         self,
         query,