Jelajahi Sumber

feat:修改裁减freqs的逻辑

zhaohaipeng 1 bulan lalu
induk
melakukan
e6f7231b19
2 mengubah file dengan 128 tambahan dan 230 penghapusan
  1. 108 143
      fish_speech/models/dac/modded_dac.py
  2. 20 87
      fish_speech/models/text2semantic/llama.py

+ 108 - 143
fish_speech/models/dac/modded_dac.py

@@ -143,14 +143,14 @@ class Transformer(nn.Module):
         self.use_kv_cache = True
         self.use_kv_cache = True
 
 
     def forward(
     def forward(
-            self,
-            x: Tensor,
-            input_pos: Optional[Tensor] = None,
-            mask: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         if self.config.pos_embed_type == "rope":
         if self.config.pos_embed_type == "rope":
             assert (
             assert (
-                    self.freqs_cis is not None
+                self.freqs_cis is not None
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             ), "RoPE frequencies must be initialized for RoPE positional embedding"
             # print("MAX", input_pos.max())
             # print("MAX", input_pos.max())
             freqs_cis = self.freqs_cis[input_pos]
             freqs_cis = self.freqs_cis[input_pos]
@@ -182,11 +182,11 @@ class TransformerBlock(nn.Module):
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
         self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
 
 
     def forward(
     def forward(
-            self,
-            x: Tensor,
-            input_pos: Tensor,
-            freqs_cis: Tensor,
-            mask: Tensor,
+        self,
+        x: Tensor,
+        input_pos: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
     ) -> Tensor:
     ) -> Tensor:
         h = x + self.attention_layer_scale(
         h = x + self.attention_layer_scale(
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
             self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
@@ -241,16 +241,14 @@ class Attention(nn.Module):
         return rel_logits
         return rel_logits
 
 
     def forward(
     def forward(
-            self,
-            x: Tensor,
-            freqs_cis: Tensor,
-            mask: Tensor,
-            input_pos: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
         bsz, seqlen, _ = x.shape
 
 
-        print(f"Attention forward self.n_local_heads {self.n_local_heads}, self.head_dim {self.head_dim}")
-
         kv_size = self.n_local_heads * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
         q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
         q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
         context_seqlen = seqlen
         context_seqlen = seqlen
@@ -259,48 +257,15 @@ class Attention(nn.Module):
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
 
 
-        # if self.pos_embed_type == "rope":
-        #     q = apply_rotary_emb(q, freqs_cis)
-        #     k = apply_rotary_emb(k, freqs_cis)
+        if self.pos_embed_type == "rope":
+            q = apply_rotary_emb(q, freqs_cis)
+            k = apply_rotary_emb(k, freqs_cis)
 
 
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
 
         if self.kv_cache is not None:
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
             k, v = self.kv_cache.update(input_pos, k, v)
 
 
-            if input_pos is not None:
-                # =========================
-                # 🔥 KV cache window 裁剪(核心优化)
-                # =========================
-                max_context = 4096  # ⭐ 推荐 4K 或 8K
-
-                # 当前有效长度
-                seq_len = int(input_pos.max().item()) + 1
-
-                # window 起点
-                start = max(0, seq_len - max_context)
-
-                # 裁剪 KV
-                k = k[:, :, start:seq_len, :]
-                v = v[:, :, start:seq_len, :]
-
-                # =========================
-                # 🔥 同步裁剪 mask(如果有)
-                # =========================
-                if mask is not None:
-                    mask = mask[:, :, :, start:seq_len]
-
-                # =========================
-                # 🔥 同步裁剪 RoPE(关键,不然会炸)
-                # =========================
-                print(f"input_pos.dtype {input_pos.dtype}")
-                assert input_pos.dtype == torch.long
-                freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
-
-        if self.pos_embed_type == "rope":
-            q = apply_rotary_emb(q, freqs_cis)
-            k = apply_rotary_emb(k, freqs_cis)
-
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
 
@@ -368,10 +333,10 @@ class RMSNorm(nn.Module):
 
 
 class LayerScale(nn.Module):
 class LayerScale(nn.Module):
     def __init__(
     def __init__(
-            self,
-            dim: int,
-            init_values: Union[float, Tensor] = 1e-2,
-            inplace: bool = False,
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-2,
+        inplace: bool = False,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         self.inplace = inplace
         self.inplace = inplace
@@ -387,12 +352,12 @@ class WindowLimitedTransformer(Transformer):
     """
     """
 
 
     def __init__(
     def __init__(
-            self,
-            config: ModelArgs,
-            input_dim: int = 512,
-            window_size: Optional[int] = None,
-            causal: bool = True,
-            look_ahead_conv: nn.Module = None,
+        self,
+        config: ModelArgs,
+        input_dim: int = 512,
+        window_size: Optional[int] = None,
+        causal: bool = True,
+        look_ahead_conv: nn.Module = None,
     ):
     ):
         super().__init__(config)
         super().__init__(config)
         self.window_size = window_size
         self.window_size = window_size
@@ -413,9 +378,9 @@ class WindowLimitedTransformer(Transformer):
         )
         )
 
 
     def make_window_limited_mask(
     def make_window_limited_mask(
-            self,
-            max_length: int,
-            x_lens: Optional[Tensor] = None,
+        self,
+        max_length: int,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         """
         """
         Make mask to form window limited attention.
         Make mask to form window limited attention.
@@ -433,9 +398,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
         return mask
 
 
     def make_mask(
     def make_mask(
-            self,
-            max_length: int,
-            x_lens: Optional[Tensor] = None,
+        self,
+        max_length: int,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         """
         """
         Make ordinary mask if window size is not specified.
         Make ordinary mask if window size is not specified.
@@ -451,9 +416,9 @@ class WindowLimitedTransformer(Transformer):
         return mask
         return mask
 
 
     def forward(
     def forward(
-            self,
-            x: Tensor,
-            x_lens: Optional[Tensor] = None,
+        self,
+        x: Tensor,
+        x_lens: Optional[Tensor] = None,
     ) -> Tensor:
     ) -> Tensor:
         if self.channels_first:
         if self.channels_first:
             x = x.transpose(1, 2)
             x = x.transpose(1, 2)
@@ -475,10 +440,10 @@ class WindowLimitedTransformer(Transformer):
 
 
 
 
 def precompute_freqs_cis(
 def precompute_freqs_cis(
-        seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
+    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
 ) -> Tensor:
 ) -> Tensor:
     freqs = 1.0 / (
     freqs = 1.0 / (
-            base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     )
     t = torch.arange(seq_len, device=freqs.device)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs = torch.outer(t, freqs)
@@ -518,7 +483,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
 
 
 
 
 def get_extra_padding_for_conv1d(
 def get_extra_padding_for_conv1d(
-        x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
 ) -> int:
 ) -> int:
     """See `pad_for_conv1d`."""
     """See `pad_for_conv1d`."""
     length = x.shape[-1]
     length = x.shape[-1]
@@ -528,10 +493,10 @@ def get_extra_padding_for_conv1d(
 
 
 
 
 def pad1d(
 def pad1d(
-        x: torch.Tensor,
-        paddings: tp.Tuple[int, int],
-        mode: str = "zeros",
-        value: float = 0.0,
+    x: torch.Tensor,
+    paddings: tp.Tuple[int, int],
+    mode: str = "zeros",
+    value: float = 0.0,
 ):
 ):
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
     If this is the case, we insert extra 0 padding to the right
     If this is the case, we insert extra 0 padding to the right
@@ -555,14 +520,14 @@ def pad1d(
 
 
 class CausalConvNet(nn.Module):
 class CausalConvNet(nn.Module):
     def __init__(
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            dilation=1,
-            stride=1,
-            groups=1,
-            padding=None,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        dilation=1,
+        stride=1,
+        groups=1,
+        padding=None,
     ):
     ):
         super(CausalConvNet, self).__init__()
         super(CausalConvNet, self).__init__()
         self.conv = nn.Conv1d(
         self.conv = nn.Conv1d(
@@ -597,7 +562,7 @@ class CausalConvNet(nn.Module):
 
 
 class CausalTransConvNet(nn.Module):
 class CausalTransConvNet(nn.Module):
     def __init__(
     def __init__(
-            self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
+        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
     ):
     ):
         super(CausalTransConvNet, self).__init__()
         super(CausalTransConvNet, self).__init__()
         self.conv = nn.ConvTranspose1d(
         self.conv = nn.ConvTranspose1d(
@@ -651,18 +616,18 @@ class ResidualUnit(nn.Module):
             if self.causal:
             if self.causal:
                 x = x[..., :-pad]
                 x = x[..., :-pad]
             else:
             else:
-                x = x[..., pad // 2: -pad // 2]
+                x = x[..., pad // 2 : -pad // 2]
         return x + y
         return x + y
 
 
 
 
 class EncoderBlock(nn.Module):
 class EncoderBlock(nn.Module):
     def __init__(
     def __init__(
-            self,
-            dim: int = 16,
-            stride: int = 1,
-            causal: bool = False,
-            n_t_layer: int = 0,
-            transformer_general_config=None,
+        self,
+        dim: int = 16,
+        stride: int = 1,
+        causal: bool = False,
+        n_t_layer: int = 0,
+        transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -704,13 +669,13 @@ class EncoderBlock(nn.Module):
 
 
 class Encoder(nn.Module):
 class Encoder(nn.Module):
     def __init__(
     def __init__(
-            self,
-            d_model: int = 64,
-            strides: list = [2, 4, 8, 8],
-            d_latent: int = 64,
-            n_transformer_layers: list = [0, 0, 4, 4],
-            transformer_general_config: ModelArgs = None,
-            causal: bool = False,
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 64,
+        n_transformer_layers: list = [0, 0, 4, 4],
+        transformer_general_config: ModelArgs = None,
+        causal: bool = False,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -746,13 +711,13 @@ class Encoder(nn.Module):
 
 
 class DecoderBlock(nn.Module):
 class DecoderBlock(nn.Module):
     def __init__(
     def __init__(
-            self,
-            input_dim: int = 16,
-            output_dim: int = 8,
-            stride: int = 1,
-            causal: bool = False,
-            n_t_layer: int = 0,
-            transformer_general_config=None,
+        self,
+        input_dim: int = 16,
+        output_dim: int = 8,
+        stride: int = 1,
+        causal: bool = False,
+        n_t_layer: int = 0,
+        transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
         conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
@@ -794,14 +759,14 @@ class DecoderBlock(nn.Module):
 
 
 class Decoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(
     def __init__(
-            self,
-            input_channel,
-            channels,
-            rates,
-            d_out: int = 1,
-            causal: bool = False,
-            n_transformer_layers: list = [0, 0, 0, 0],
-            transformer_general_config=None,
+        self,
+        input_channel,
+        channels,
+        rates,
+        d_out: int = 1,
+        causal: bool = False,
+        n_transformer_layers: list = [0, 0, 0, 0],
+        transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
         conv_class = CausalWNConv1d if causal else WNConv1d
         conv_class = CausalWNConv1d if causal else WNConv1d
@@ -810,7 +775,7 @@ class Decoder(nn.Module):
 
 
         # Add upsampling + MRF blocks
         # Add upsampling + MRF blocks
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
         for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
-            input_dim = channels // 2 ** i
+            input_dim = channels // 2**i
             output_dim = channels // 2 ** (i + 1)
             output_dim = channels // 2 ** (i + 1)
             layers += [
             layers += [
                 DecoderBlock(
                 DecoderBlock(
@@ -838,19 +803,19 @@ class Decoder(nn.Module):
 
 
 class DAC(BaseModel, CodecMixin):
 class DAC(BaseModel, CodecMixin):
     def __init__(
     def __init__(
-            self,
-            encoder_dim: int = 64,
-            encoder_rates: List[int] = [2, 4, 8, 8],
-            latent_dim: int = None,
-            decoder_dim: int = 1536,
-            decoder_rates: List[int] = [8, 8, 4, 2],
-            quantizer: torch.nn.Module = None,
-            sample_rate: int = 44100,
-            causal: bool = True,
-            encoder_transformer_layers: List[int] = [0, 0, 0, 0],
-            decoder_transformer_layers: List[int] = [0, 0, 0, 0],
-            overwrite_decoder: torch.nn.Module = None,
-            transformer_general_config=None,
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        quantizer: torch.nn.Module = None,
+        sample_rate: int = 44100,
+        causal: bool = True,
+        encoder_transformer_layers: List[int] = [0, 0, 0, 0],
+        decoder_transformer_layers: List[int] = [0, 0, 0, 0],
+        overwrite_decoder: torch.nn.Module = None,
+        transformer_general_config=None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -907,11 +872,11 @@ class DAC(BaseModel, CodecMixin):
         return audio_data
         return audio_data
 
 
     def encode(
     def encode(
-            self,
-            audio_data: torch.Tensor,
-            audio_lengths: torch.Tensor = None,
-            n_quantizers: int = None,
-            **kwargs,
+        self,
+        audio_data: torch.Tensor,
+        audio_lengths: torch.Tensor = None,
+        n_quantizers: int = None,
+        **kwargs,
     ):
     ):
         """Encode given audio data and return quantized latent codes
         """Encode given audio data and return quantized latent codes
 
 
@@ -981,13 +946,13 @@ class DAC(BaseModel, CodecMixin):
         return self.decoder(z)
         return self.decoder(z)
 
 
     def forward(
     def forward(
-            self,
-            audio_data: torch.Tensor,
-            template: torch.Tensor = None,
-            mask: torch.Tensor = None,
-            sample_rate: int = None,
-            n_quantizers: int = None,
-            **kwargs,
+        self,
+        audio_data: torch.Tensor,
+        template: torch.Tensor = None,
+        mask: torch.Tensor = None,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+        **kwargs,
     ):
     ):
         """Model forward pass
         """Model forward pass
 
 

+ 20 - 87
fish_speech/models/text2semantic/llama.py

@@ -887,20 +887,16 @@ class Attention(nn.Module):
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
 
     def forward(
     def forward(
-            self,
-            x: torch.Tensor,
-            freqs_cis: torch.Tensor,
-            mask: Optional[torch.Tensor],
-            input_pos: Optional[torch.Tensor] = None,
-    ):
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
         bsz, seqlen, _ = x.shape
         bsz, seqlen, _ = x.shape
 
 
         q_size = self.n_head * self.head_dim
         q_size = self.n_head * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
-
-        # =========================
-        # QKV projection
-        # =========================
         q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
         q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
 
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
@@ -911,87 +907,28 @@ class Attention(nn.Module):
             q = self.q_norm(q)
             q = self.q_norm(q)
             k = self.k_norm(k)
             k = self.k_norm(k)
 
 
-        # [B, H, T, D]
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
 
 
-        # =========================
-        # KV Cache + Sliding Window
-        # =========================
-        start = 0
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
 
         if self.kv_cache is not None:
         if self.kv_cache is not None:
-            # update cache
             k, v = self.kv_cache.update(input_pos, k, v)
             k, v = self.kv_cache.update(input_pos, k, v)
 
 
-            max_context = 4096  # 可调
-
-            if input_pos is not None and seqlen == 1:
-                # decode
-                seq_len = int(input_pos.item()) + 1
-            else:
-                # prefill
-                seq_len = k.size(2)
-
-            start = max(0, seq_len - max_context)
-
-            # window 裁剪
-            k = k[:, :, start:seq_len, :]
-            v = v[:, :, start:seq_len, :]
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
 
-            if mask is not None:
-                mask = mask[:, :, :, start:seq_len]
-        else:
-            seq_len = seqlen
-
-        # =========================
-        # RoPE(核心:以 K 为基准)
-        # =========================
-        T_k = k.size(2)
-
-        if self.kv_cache is not None:
-            freqs = freqs_cis[start: start + T_k]
-        else:
-            freqs = freqs_cis[:T_k]
-
-        # =========================
-        # Apply RoPE(Q/K 分开对齐)
-        # =========================
-        T_q = q.size(2)
-
-        print("Q:", q.shape)
-        print("K:", k.shape)
-        print("freqs:", freqs.shape)
-        print("start:", start)
-
-        assert k.size(2) == freqs.size(0), f"K vs freqs mismatch"
-        assert q.size(2) <= k.size(2), f"Q longer than K??"
-
-        # Q 用最后 T_q 个位置
-        q = apply_rotary_emb(q, freqs[-T_q:])
-
-        # K 用完整 window
-        k = apply_rotary_emb(k, freqs)
-
-        # =========================
-        # GQA expand
-        # =========================
-        if self.n_head != self.n_local_heads:
-            repeat = self.n_head // self.n_local_heads
-            k = k.repeat_interleave(repeat, dim=1)
-            v = v.repeat_interleave(repeat, dim=1)
-
-        # =========================
-        # Attention
-        # =========================
         if self.use_sdpa:
         if self.use_sdpa:
             if mask is None:
             if mask is None:
-                y = torch.nn.functional.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    dropout_p=self.dropout if self.training else 0.0,
-                    is_causal=True,
-                )
+                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                    y = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=True,
+                        # No third party attn_mask here to use flash_attention
+                    )
             else:
             else:
                 y = torch.nn.functional.scaled_dot_product_attention(
                 y = torch.nn.functional.scaled_dot_product_attention(
                     q,
                     q,
@@ -1009,14 +946,10 @@ class Attention(nn.Module):
                 dropout_p=self.dropout if self.training else 0.0,
                 dropout_p=self.dropout if self.training else 0.0,
             )
             )
 
 
-        # =========================
-        # Output
-        # =========================
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
 
         return self.wo(y)
         return self.wo(y)
 
 
-
     def eq_scaled_dot_product_attention(
     def eq_scaled_dot_product_attention(
         self,
         self,
         query,
         query,