Lengyue 2 лет назад
Родитель
Commit
7a31b4043a

+ 5 - 10
fish_speech/configs/text2semantic_pretrain_small.yaml

@@ -2,9 +2,8 @@ defaults:
   - base
   - base
   - _self_
   - _self_
 
 
-project: text2semantic_pretrain_small_4_in_8_codebooks
+project: text2semantic_pretrain_small_dual_ar
 max_length: 2048
 max_length: 2048
-use_delay_pattern: true
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
@@ -29,7 +28,6 @@ train_dataset:
   use_speaker: false
   use_speaker: false
   phones_prob: 0.5
   phones_prob: 0.5
   interactive_prob: 0.5
   interactive_prob: 0.5
-  use_delay_pattern: ${use_delay_pattern}
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -39,7 +37,6 @@ val_dataset:
   use_speaker: false
   use_speaker: false
   phones_prob: 0.5
   phones_prob: 0.5
   interactive_prob: 0.5
   interactive_prob: 0.5
-  use_delay_pattern: ${use_delay_pattern}
 
 
 data:
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   _target_: fish_speech.datasets.text.TextDataModule
@@ -59,18 +56,16 @@ model:
     _target_: fish_speech.models.text2semantic.llama.Transformer
     _target_: fish_speech.models.text2semantic.llama.Transformer
     config:
     config:
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
-      max_seq_len: 4096
+      max_seq_len: ${max_length}
       vocab_size: 36408
       vocab_size: 36408
-      n_layer: 12
+      n_slow_layer: 12
+      n_fast_layer: 4
       n_head: 12
       n_head: 12
       dim: 768
       dim: 768
       rope_base: 10000
       rope_base: 10000
       norm_eps: 1e-5
       norm_eps: 1e-5
-      num_in_codebooks: 4 # input codebook size
-      num_codebooks: 8  # output codebook size
+      num_codebooks: 8  # input/output codebook size
       codebook_size: 264 # codebook size 256 + 2 special tokens
       codebook_size: 264 # codebook size 256 + 2 special tokens
-      dropout: 0.1
-      neft_alpha: 10
 
 
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW

+ 17 - 18
fish_speech/datasets/text.py

@@ -199,7 +199,6 @@ class AutoAugTextDataset(IterableDataset):
         mix_text_phone_prob: float = 0.5,
         mix_text_phone_prob: float = 0.5,
         use_negative_samples: bool = False,
         use_negative_samples: bool = False,
         num_codebooks: Optional[int] = None,
         num_codebooks: Optional[int] = None,
-        use_delay_pattern: bool = True,
     ):
     ):
         """
         """
         Args:
         Args:
@@ -217,7 +216,6 @@ class AutoAugTextDataset(IterableDataset):
             mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
             mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
             use_negative_samples: generate negative samples
             use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
             num_codebooks: number of codebooks, if None, it will be automatically detected
-            use_delay_pattern: use delay pattern for codebooks
         """
         """
 
 
         super().__init__()
         super().__init__()
@@ -240,7 +238,8 @@ class AutoAugTextDataset(IterableDataset):
         self.mix_text_phone_prob = mix_text_phone_prob
         self.mix_text_phone_prob = mix_text_phone_prob
         self.use_negative_samples = use_negative_samples
         self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
         self.num_codebooks = num_codebooks
-        self.use_delay_pattern = use_delay_pattern
+
+        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<s:0>")
 
 
         if use_data_server is True:
         if use_data_server is True:
             self.channel = grpc.insecure_channel(server)
             self.channel = grpc.insecure_channel(server)
@@ -497,12 +496,9 @@ class AutoAugTextDataset(IterableDataset):
         bos_bias = 1 if add_bos else 0
         bos_bias = 1 if add_bos else 0
 
 
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-        pad_token_length = semantic_length + (
-            num_codebooks - 1 if self.use_delay_pattern else 0
-        )
         tokens = (
         tokens = (
             encoded
             encoded
-            + [self.tokenizer.pad_token_id] * pad_token_length
+            + [self.semantic_token_id] * semantic_length
             + [self.tokenizer.eos_token_id]
             + [self.tokenizer.eos_token_id]
         )
         )
 
 
@@ -510,10 +506,8 @@ class AutoAugTextDataset(IterableDataset):
             tokens = [self.tokenizer.bos_token_id] + tokens
             tokens = [self.tokenizer.bos_token_id] + tokens
 
 
         # Codebook bos/padding: 0, eos: 1
         # Codebook bos/padding: 0, eos: 1
-        # Implement delay pattern
         codes = [
         codes = [
-            [CODEBOOK_PAD_TOKEN_ID]
-            * (prompt_length + bos_bias + (i if self.use_delay_pattern else 0))
+            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
             for i in range(num_codebooks)
             for i in range(num_codebooks)
         ]
         ]
         for segment in semantics:
         for segment in semantics:
@@ -522,8 +516,6 @@ class AutoAugTextDataset(IterableDataset):
                     codes[book_idx].append(int(j) + 2)
                     codes[book_idx].append(int(j) + 2)
 
 
         for idx, book in enumerate(codes):
         for idx, book in enumerate(codes):
-            if self.use_delay_pattern:
-                book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
             book.append(CODEBOOK_EOS_TOKEN_ID)
             book.append(CODEBOOK_EOS_TOKEN_ID)
 
 
         tokens = [tokens] + codes
         tokens = [tokens] + codes
@@ -577,10 +569,17 @@ class TextDataCollator:
 
 
     def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
     def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
         tokens, attention_masks, labels = [], [], []
         tokens, attention_masks, labels = [], [], []
+
+        # Calculate the max length
+        max_tokens_length = 0
+        for example in examples:
+            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
+        max_tokens_length = min(max_tokens_length, self.max_length)
+
         for example in examples:
         for example in examples:
-            _tokens = example[tokens_key][:, : self.max_length]
-            _labels = example[labels_key][:, : self.max_length]
-            _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
+            _tokens = example[tokens_key][:, :max_tokens_length]
+            _labels = example[labels_key][:, :max_tokens_length]
+            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
             tokens_length = _tokens.size(1)
             tokens_length = _tokens.size(1)
             _attention_mask[:tokens_length] = False
             _attention_mask[:tokens_length] = False
 
 
@@ -588,15 +587,15 @@ class TextDataCollator:
                 1
                 1
             ), f"{tokens_length} != {_labels.size(1)}"
             ), f"{tokens_length} != {_labels.size(1)}"
 
 
-            if tokens_length < self.max_length:
+            if tokens_length < max_tokens_length:
                 _tokens = F.pad(
                 _tokens = F.pad(
                     _tokens,
                     _tokens,
-                    (0, self.max_length - tokens_length),
+                    (0, max_tokens_length - tokens_length),
                     value=self.tokenizer.eos_token_id,
                     value=self.tokenizer.eos_token_id,
                 )
                 )
                 _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
                 _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
                 _labels = F.pad(
                 _labels = F.pad(
-                    _labels, (0, self.max_length - _labels.size(1)), value=-100
+                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                 )
                 )
 
 
             tokens.append(_tokens)
             tokens.append(_tokens)

+ 7 - 23
fish_speech/models/text2semantic/lit_module.py

@@ -163,11 +163,11 @@ class TextToSemantic(L.LightningModule):
 
 
     def _step(self, batch, batch_idx, stage: str):
     def _step(self, batch, batch_idx, stage: str):
         # Do positive and negative samples in the same batch to speed up training
         # Do positive and negative samples in the same batch to speed up training
+        labels = batch["labels"]
         outputs = self.model(
         outputs = self.model(
             x=batch["inputs"],
             x=batch["inputs"],
             key_padding_mask=batch["attention_masks"],
             key_padding_mask=batch["attention_masks"],
         )
         )
-        labels = batch["labels"]
         token_logits = outputs.token_logits
         token_logits = outputs.token_logits
         codebook_logits = outputs.codebook_logits
         codebook_logits = outputs.codebook_logits
 
 
@@ -186,7 +186,12 @@ class TextToSemantic(L.LightningModule):
 
 
         # If we have a codebook, add the loss
         # If we have a codebook, add the loss
         if self.model.config.num_codebooks != 0:
         if self.model.config.num_codebooks != 0:
-            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+            # We want to shift the labels by one to the right
+            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks, :-1]
+            codebook_labels = torch.nn.functional.pad(
+                codebook_labels, (0, 1), value=-100
+            ).mT
+
             semantic_loss = F.cross_entropy(
             semantic_loss = F.cross_entropy(
                 codebook_logits.reshape(-1, codebook_logits.size(-1)),
                 codebook_logits.reshape(-1, codebook_logits.size(-1)),
                 codebook_labels.reshape(-1),
                 codebook_labels.reshape(-1),
@@ -312,27 +317,6 @@ class TextToSemantic(L.LightningModule):
             logger=True,
             logger=True,
         )
         )
 
 
-        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
-            _, indices = codebook_logits[
-                :, :, : self.model.config.num_in_codebooks
-            ].topk(5, dim=-1)
-            codebook_labels = codebook_labels[
-                :, :, : self.model.config.num_in_codebooks
-            ]
-            correct = indices.eq(codebook_labels.unsqueeze(-1))
-            correct[codebook_labels == -100] = 0
-            correct = correct.sum()
-            accuracy = correct / (codebook_labels != -100).sum()
-
-            self.log(
-                f"{stage}/top_5_accuracy_in",
-                accuracy,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=True,
-                logger=True,
-            )
-
         return loss
         return loss
 
 
     def training_step(self, batch, batch_idx):
     def training_step(self, batch, batch_idx):

+ 141 - 41
fish_speech/models/text2semantic/llama.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 from einops import rearrange
 from einops import rearrange
 from torch import Tensor
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 
 
 def find_multiple(n: int, k: int) -> int:
 def find_multiple(n: int, k: int) -> int:
@@ -18,7 +19,8 @@ def find_multiple(n: int, k: int) -> int:
 @dataclass
 @dataclass
 class ModelArgs:
 class ModelArgs:
     vocab_size: int = 32000
     vocab_size: int = 32000
-    n_layer: int = 32
+    n_slow_layer: int = 32
+    n_fast_layer: int = 4
     n_head: int = 32
     n_head: int = 32
     dim: int = 4096
     dim: int = 4096
     intermediate_size: int = None
     intermediate_size: int = None
@@ -32,15 +34,11 @@ class ModelArgs:
     # Additional decoding heads
     # Additional decoding heads
     codebook_size: int = 160
     codebook_size: int = 160
     num_codebooks: int = 4
     num_codebooks: int = 4
-    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0
     codebook_padding_idx: int = 0
 
 
     # Gradient checkpointing
     # Gradient checkpointing
     use_gradient_checkpointing: bool = True
     use_gradient_checkpointing: bool = True
 
 
-    # NEFT
-    neft_alpha: float = 0
-
     def __post_init__(self):
     def __post_init__(self):
         if self.n_local_heads == -1:
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
             self.n_local_heads = self.n_head
@@ -48,8 +46,6 @@ class ModelArgs:
             hidden_dim = 4 * self.dim
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        if self.num_in_codebooks is None:
-            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
         self.head_dim = self.dim // self.n_head
 
 
 
 
@@ -85,17 +81,34 @@ class Transformer(nn.Module):
         super().__init__()
         super().__init__()
         self.config = config
         self.config = config
 
 
+        # Slow transformer
         self.embeddings = nn.Embedding(
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.vocab_size + config.codebook_size * config.num_codebooks,
             config.dim,
             config.dim,
         )
         )
-        self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layer)
+        self.slow_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
         )
         )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.output = nn.Linear(
+        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.slow_output = nn.Linear(
             config.dim,
             config.dim,
-            config.vocab_size + config.codebook_size * config.num_codebooks,
+            config.vocab_size,
+            bias=False,
+        )
+
+        # Fast transformer
+        self.fast_embeddings = nn.Embedding(
+            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
+        )
+
+        # The equivalent bs is so large that sdpa doesn't work
+        self.fast_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+        )
+        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_output = nn.Linear(
+            config.dim,
+            config.codebook_size,
             bias=False,
             bias=False,
         )
         )
 
 
@@ -106,6 +119,7 @@ class Transformer(nn.Module):
                 config.dim // config.n_head,
                 config.dim // config.n_head,
                 config.rope_base,
                 config.rope_base,
             ),
             ),
+            persistent=False,
         )
         )
         self.register_buffer(
         self.register_buffer(
             "causal_mask",
             "causal_mask",
@@ -116,6 +130,7 @@ class Transformer(nn.Module):
                     dtype=torch.bool,
                     dtype=torch.bool,
                 )
                 )
             ),
             ),
+            persistent=False,
         )
         )
 
 
         # For kv cache
         # For kv cache
@@ -144,11 +159,11 @@ class Transformer(nn.Module):
 
 
     def embed(self, x: Tensor) -> Tensor:
     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
         # Here we want to merge the embeddings of the codebooks
-        if self.config.num_in_codebooks == 0:
+        if self.config.num_codebooks == 0:
             return self.embeddings(x[:, 0])
             return self.embeddings(x[:, 0])
 
 
         vocab_embeds = [self.embeddings(x[:, 0])]
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_in_codebooks):
+        for i in range(self.config.num_codebooks):
             emb = self.embeddings(
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
             )
@@ -158,15 +173,6 @@ class Transformer(nn.Module):
         x = torch.stack(vocab_embeds, dim=3)
         x = torch.stack(vocab_embeds, dim=3)
         x = x.sum(dim=3)
         x = x.sum(dim=3)
 
 
-        if self.config.neft_alpha > 0 and self.training:
-            # alpha / sqrt(L * D)
-            scaled_alpha = self.config.neft_alpha / math.sqrt(
-                self.config.dim * x.shape[2]
-            )
-            x += torch.rand_like(x) * scaled_alpha
-
-            print("NEFT alpha:", scaled_alpha)
-
         return x
         return x
 
 
     def compute(
     def compute(
@@ -176,11 +182,11 @@ class Transformer(nn.Module):
         mask: Tensor,
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
         input_pos: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
     ) -> TransformerForwardResult:
+        raise NotImplementedError
+
         for layer in self.layers:
         for layer in self.layers:
             if self.config.use_gradient_checkpointing and self.training:
             if self.config.use_gradient_checkpointing and self.training:
-                x = torch.utils.checkpoint.checkpoint(
-                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
-                )
+                x = checkpoint(layer, x, freqs_cis, mask, input_pos, use_reentrant=True)
             else:
             else:
                 x = layer(x, freqs_cis, mask, input_pos=input_pos)
                 x = layer(x, freqs_cis, mask, input_pos=input_pos)
 
 
@@ -210,6 +216,10 @@ class Transformer(nn.Module):
         # x: (batch, num_codebooks + 1, seq_len)
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = x.size(2)
         seq_len = x.size(2)
 
 
+        # For codebook, the decoding is actually shifted by 1
+        # Which is the labels section
+        codebooks = x[:, 1:]
+
         # Here we want to merge the embeddings of the codebooks
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(x)
         x = self.embed(x)
 
 
@@ -222,9 +232,59 @@ class Transformer(nn.Module):
         if key_padding_mask is not None:
         if key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
 
-        return self.compute(x, freqs_cis, mask)
+        for layer in self.slow_layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+            else:
+                x = layer(x, freqs_cis, mask)
+
+        # We got slow_out here
+        slow_out = self.slow_norm(x)
+        token_logits = self.slow_output(slow_out)
+
+        # Fast transformer
+        fast_seq_len = self.config.num_codebooks
+        fast_mask = self.causal_mask[
+            None, None, :fast_seq_len, :fast_seq_len
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+        # NOTE: there may be a bug here
+        # Say at t0, the given input is [/INST] for semantic token
+        # Then we want to predict <tok0>, <tok1>, ... (instead of <s> <s> <s>) given <feat>, <tok0>, <tok1>, ...
+        # Otherwise this becomes: decode tokens from same given tokens
+        # Ignore the last token, since the input should be <feat>, <tok0>, <tok1>, ...
+        codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
+
+        x = torch.cat([x[:, None, 1:], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
+        b, s = x.size(0), x.size(2)
+        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
+
+        for layer in self.fast_layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+            else:
+                x = layer(x, fast_freqs_cis, fast_mask)
+
+        # unflatten the batch and num_codebooks
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
+        assert codebook_logits.shape[1] == self.config.num_codebooks
+        codebook_logits = rearrange(
+            codebook_logits,
+            "(b s) n d -> b s n d",
+            b=b,
+            s=s,
+            n=self.config.num_codebooks,
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
 
 
     def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        ### TODO: fix this
         # x: (batch, num_codebooks + 1, 1)
         # x: (batch, num_codebooks + 1, 1)
 
 
         assert (
         assert (
@@ -244,9 +304,9 @@ class Transformer(nn.Module):
 
 
 
 
 class TransformerBlock(nn.Module):
 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
         super().__init__()
         super().__init__()
-        self.attention = Attention(config)
+        self.attention = Attention(config, use_sdpa=use_sdpa)
         self.feed_forward = FeedForward(config)
         self.feed_forward = FeedForward(config)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
@@ -260,7 +320,7 @@ class TransformerBlock(nn.Module):
 
 
 
 
 class Attention(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
         super().__init__()
         super().__init__()
         assert config.dim % config.n_head == 0
         assert config.dim % config.n_head == 0
 
 
@@ -275,6 +335,7 @@ class Attention(nn.Module):
         self.head_dim = config.head_dim
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
         self.dim = config.dim
+        self.use_sdpa = use_sdpa
         self._register_load_state_dict_pre_hook(self.load_hook)
         self._register_load_state_dict_pre_hook(self.load_hook)
 
 
     def load_hook(self, state_dict, prefix, *args):
     def load_hook(self, state_dict, prefix, *args):
@@ -310,18 +371,56 @@ class Attention(nn.Module):
 
 
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(
-            q,
-            k,
-            v,
-            attn_mask=mask,
-            dropout_p=self.dropout if self.training else 0.0,
-        )
+
+        if self.use_sdpa:
+            y = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
+        else:
+            y = self.eq_scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
 
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
 
         return self.wo(y)
         return self.wo(y)
 
 
+    def eq_scaled_dot_product_attention(
+        self,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+    ) -> torch.Tensor:
+        # This is a standard scaled dot product attention
+        # It's low efficient, but it doesn't raise cuda error
+
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+        return attn_weight @ value
+
 
 
 class FeedForward(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
     def __init__(self, config: ModelArgs) -> None:
@@ -378,13 +477,14 @@ if __name__ == "__main__":
     args = ModelArgs(
     args = ModelArgs(
         max_seq_len=4096,
         max_seq_len=4096,
         vocab_size=32312,
         vocab_size=32312,
-        n_layer=12,
+        n_slow_layer=12,
+        n_fast_layer=4,
         n_head=12,
         n_head=12,
         dim=768,
         dim=768,
         rope_base=10000,
         rope_base=10000,
         norm_eps=1e-5,
         norm_eps=1e-5,
-        codebook_size=0,
-        num_codebooks=0,
+        codebook_size=128,
+        num_codebooks=4,
     )
     )
 
 
     model = Transformer(args)
     model = Transformer(args)