Forráskód Böngészése

Remove flash attn deps as PyTorch ships in 2.2.0

Lengyue 2 éve
szülő
commit
cecee062c4
3 módosított fájl, 16 hozzáadás és 180 törlés
  1. docs/en/index.md (+1 −4)
  2. docs/zh/index.md (+1 −4)
  3. fish_speech/models/text2semantic/llama.py (+14 −172)

+ 1 - 4
docs/en/index.md

@@ -11,7 +11,7 @@ This codebase is released under the `BSD-3-Clause` license, and all models are r
 
 ## Requirements
 - GPU Memory: 2GB (for inference), 16GB (for fine-tuning)
-- System: Linux (full functionality), Windows (inference only, no support for `flash-attn`, no support for `torch.compile`)
+- System: Linux (full functionality), Windows (inference only, no support for `torch.compile`)
 
 Therefore, we strongly recommend Windows users to use WSL2 or docker to run the codebase.
 
@@ -24,9 +24,6 @@ conda activate fish-speech
 # Install pytorch nightly
 pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 
-# Install flash-attn (for Linux)
-pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
-
 # Install fish-speech
 pip3 install -e .
 ```

+ 1 - 4
docs/zh/index.md

@@ -11,7 +11,7 @@
 
 ## 要求
 - GPU内存: 2GB (用于推理), 16GB (用于微调)
-- 系统: Linux (全部功能), Windows (仅推理, 不支持 `flash-attn`, 不支持 `torch.compile`)
+- 系统: Linux (全部功能), Windows (仅推理, 不支持 `torch.compile`)
 
 因此, 我们强烈建议 Windows 用户使用 WSL2 或 docker 来运行代码库.
 
@@ -24,9 +24,6 @@ conda activate fish-speech
 # 安装 pytorch nightly 版本
 pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 
-# 安装 flash-attn (适用于linux)
-pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
-
 # 安装 fish-speech
 pip3 install -e .
 ```

+ 14 - 172
fish_speech/models/text2semantic/llama.py

@@ -7,11 +7,6 @@ import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
 from torch.nn import functional as F
-from transformers.utils import is_flash_attn_2_available
-
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
 def find_multiple(n: int, k: int) -> int:
@@ -40,9 +35,6 @@ class ModelArgs:
     num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0
 
-    # Use flash attention
-    use_flash_attention: bool = False
-
     # Gradient checkpointing
     use_gradient_checkpointing: bool = True
 
@@ -225,10 +217,8 @@ class Transformer(nn.Module):
         # Not that the causal mask here follows the definition of scaled_dot_product_attention
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask use TRUE to mask out
-        if self.config.use_flash_attention is False and key_padding_mask is not None:
+        if key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
-        elif self.config.use_flash_attention is True and key_padding_mask is not None:
-            mask = key_padding_mask.logical_not()
 
         return self.compute(x, freqs_cis, mask)
 
@@ -283,7 +273,6 @@ class Attention(nn.Module):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        self.use_flash_attention = config.use_flash_attention
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -312,171 +301,24 @@ class Attention(nn.Module):
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
 
-        if self.use_flash_attention is False:
-            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-
-            if self.kv_cache is not None:
-                k, v = self.kv_cache.update(input_pos, k, v)
-
-            k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-            v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-            y = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=mask,
-                dropout_p=self.dropout if self.training else 0.0,
-            )
-
-            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-        else:
-            assert (
-                self.kv_cache is None
-            ), "kv_cache is not supported for flash attention for now"
-
-            # We don't need to transpose q, k, v here because flash_attn_varlen_func
-            attn_output = self._flash_attention_forward(
-                q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
-            )
-
-            y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
-
-        return self.wo(y)
-
-    def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-
-            attn_output = pad_input(
-                attn_output_unpad, indices_q, batch_size, query_length
-            )
-        else:
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                dropout,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
-        return attn_output
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
 
-    def _get_unpad_data(self, attention_mask):
-        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-        max_seqlen_in_batch = seqlens_in_batch.max().item()
-        cu_seqlens = F.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-        )
-        return (
-            indices,
-            cu_seqlens,
-            max_seqlen_in_batch,
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask,
+            dropout_p=self.dropout if self.training else 0.0,
         )
 
-    def _upad_input(
-        self, query_layer, key_layer, value_layer, attention_mask, query_length
-    ):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
-            attention_mask
-        )
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            indices_k,
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
-                indices_k,
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
-                query_layer, attention_mask
-            )
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
+        return self.wo(y)
 
 
 class FeedForward(nn.Module):