Sfoglia il codice sorgente

Add flash attention & gradient checkpointing

Lengyue 2 anni fa
parent
commit
dcc5e80ce2

+ 4 - 2
docs/en/index.md

@@ -21,8 +21,8 @@ Therefore, we strongly recommend Windows users to use WSL2 or docker to run the
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 
-# Install pytorch
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+# Install pytorch nightly
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 
 # Install flash-attn (for Linux)
 pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
@@ -33,6 +33,7 @@ pip3 install -e .
 
 ## Changelog
 
+- 2023/12/27: Add `gradient checkpointing`, `causal sampling`, and `flash-attn` support.
 - 2023/12/19: Updated webui and HTTP API.
 - 2023/12/18: Updated fine-tuning documentation and related examples.
 - 2023/12/17: Updated `text2semantic` model, supporting phoneme-free mode.
@@ -44,3 +45,4 @@ pip3 install -e .
 - [GPT VITS](https://github.com/innnky/gpt-vits)
 - [MQTTS](https://github.com/b04901014/MQTTS)
 - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)

+ 4 - 2
docs/zh/index.md

@@ -21,8 +21,8 @@
 conda create -n fish-speech python=3.10
 conda activate fish-speech
 
-# 安装 pytorch
-pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+# 安装 pytorch nightly 版本
+pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 
 # 安装 flash-attn (适用于linux)
 pip3 install ninja && MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
@@ -33,6 +33,7 @@ pip3 install -e .
 
 ## 更新日志
 
+- 2023/12/27: 添加了 `gradient checkpointing`, `causal sampling` 和 `flash-attn` 支持.
 - 2023/12/19: 更新了 Webui 和 HTTP API.
 - 2023/12/18: 更新了微调文档和相关例子.
 - 2023/12/17: 更新了 `text2semantic` 模型, 支持无音素模式.
@@ -44,3 +45,4 @@ pip3 install -e .
 - [GPT VITS](https://github.com/innnky/gpt-vits)
 - [MQTTS](https://github.com/b04901014/MQTTS)
 - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [Transformers](https://github.com/huggingface/transformers)

+ 7 - 1
fish_speech/models/text2semantic/lit_module.py

@@ -1,9 +1,15 @@
-from typing import Any
+import platform
+from typing import Any, Optional
 
 import lightning as L
+import torch
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
+import fish_speech.utils as utils
+
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
 
 class TextToSemantic(L.LightningModule):
     def __init__(self, model, optimizer: Any, lr_scheduler: Any):

+ 176 - 11
fish_speech/models/text2semantic/llama.py

@@ -6,6 +6,11 @@ import torch.nn as nn
 from einops import rearrange
 from torch import Tensor
 from torch.nn import functional as F
+from transformers.utils import is_flash_attn_2_available
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
 def find_multiple(n: int, k: int) -> int:
@@ -32,6 +37,12 @@ class ModelArgs:
     num_codebooks: int = 4
     codebook_padding_idx: int = 0
 
+    # Use flash attention
+    use_flash_attention: bool = is_flash_attn_2_available()
+
+    # Gradient checkpointing
+    use_gradient_checkpointing: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -154,7 +165,12 @@ class Transformer(nn.Module):
         input_pos: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
         for layer in self.layers:
-            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+            if self.config.use_gradient_checkpointing and self.training:
+                x = torch.utils.checkpoint.checkpoint(
+                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
+                )
+            else:
+                x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
         x = self.norm(x)
         logits = self.output(x)
@@ -191,8 +207,10 @@ class Transformer(nn.Module):
         # Note that the causal mask here follows the definition of scaled_dot_product_attention
         # That is, FALSE means masked out
         # To maintain consistency, key_padding_mask uses TRUE to mask out
-        if key_padding_mask is not None:
+        if self.config.use_flash_attention is False and key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
+        elif self.config.use_flash_attention is True and key_padding_mask is not None:
+            mask = key_padding_mask.logical_not()
 
         return self.compute(x, freqs_cis, mask)
 
@@ -246,6 +264,7 @@ class Attention(nn.Module):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
+        self.use_flash_attention = config.use_flash_attention
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(self, state_dict, prefix, *args):
@@ -274,19 +293,165 @@ class Attention(nn.Module):
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
 
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+        if self.use_flash_attention is False:
+            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+            if self.kv_cache is not None:
+                k, v = self.kv_cache.update(input_pos, k, v)
+
+            k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+            v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        else:
+            assert (
+                self.kv_cache is None
+            ), "kv_cache is not supported for flash attention for now"
+
+            # We don't need to transpose q, k, v here because flash_attn_varlen_func
+            attn_output = self._flash_attention_forward(
+                q, k, v, mask, seqlen, dropout=0.0
+            )
+
+            y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
+
+        return self.wo(y)
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
 
-        if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        return attn_output
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+    def _get_unpad_data(self, attention_mask):
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
+        )
+        return (
+            indices,
+            cu_seqlens,
+            max_seqlen_in_batch,
+        )
 
-        y = self.wo(y)
-        return y
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
+            attention_mask
+        )
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 
 
 class FeedForward(nn.Module):

+ 0 - 4
fish_speech/train.py

@@ -68,10 +68,6 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
         log.info("Logging hyperparameters!")
         utils.log_hyperparameters(object_dict)
 
-    if cfg.get("compile"):
-        log.info("Compiling model!")
-        model = torch.compile(model)
-
     if cfg.get("train"):
         log.info("Starting training!")