Make AR and naive decoder configurable

Lengyue 2 years ago
parent
commit
26af12b8c6
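
For context, a minimal sketch (not part of the commit) of what "configurable" means here: the Hydra YAML's _target_ now selects either NaiveTransformer or DualARTransformer, and downstream code instantiates whichever class the config names, mirroring load_model in tools/llama/generate.py. The config_path below is an assumption about the working directory.

from hydra import compose, initialize
from hydra.utils import instantiate

# Compose the small pretrain config; its model._target_ now points at
# DualARTransformer (see the YAML diff below).
with initialize(version_base="1.3", config_path="fish_speech/configs"):
    cfg = compose(config_name="text2semantic_pretrain_small")

model = instantiate(cfg.model).model  # a DualARTransformer instance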

+ 0 - 14
fish_speech/configs/text2semantic_pretrain_large.yaml

@@ -1,14 +0,0 @@
-defaults:
-  - text2semantic_pretrain_small
-  - _self_
-
-project: text2semantic_pretrain_large_dual_ar
-
-# Model Configuration
-model:
-  model:
-    config:
-      n_slow_layer: 36
-      n_fast_layer: 8
-      n_head: 20
-      dim: 1280

+ 1 - 1
fish_speech/configs/text2semantic_pretrain_medium.yaml

@@ -8,7 +8,7 @@ project: text2semantic_pretrain_medium_dual_ar
 model:
   model:
     config:
-      n_slow_layer: 24
+      n_layer: 24
       n_fast_layer: 6
       n_head: 16
       dim: 1024

+ 3 - 3
fish_speech/configs/text2semantic_pretrain_small.yaml

@@ -53,12 +53,12 @@ model:
 
   model:
     # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.llama.Transformer
+    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
     config:
-      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
       max_seq_len: ${max_length}
       vocab_size: 36408
-      n_slow_layer: 12
+      n_layer: 12
       n_fast_layer: 4
       n_head: 12
       dim: 768
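
In Python terms, the renamed keys map onto the new DualARModelArgs dataclass. A rough equivalent of the small config above, as a sketch only: max_seq_len stands in for the YAML's ${max_length}, whose resolved value is not shown in this diff.

from fish_speech.models.text2semantic.llama import DualARModelArgs, DualARTransformer

args = DualARModelArgs(
    max_seq_len=2048,  # placeholder for ${max_length}; assumed value
    vocab_size=36408,
    n_layer=12,        # was n_slow_layer before this commit
    n_fast_layer=4,
    n_head=12,
    dim=768,
)
model = DualARTransformer(args)  # ~130M parameters, debug scale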

+ 50 - 43
fish_speech/models/text2semantic/lit_module.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 import fish_speech.utils as utils
-from fish_speech.models.text2semantic.llama import Transformer
+from fish_speech.models.text2semantic.llama import NaiveTransformer
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
@@ -23,7 +23,7 @@ class LoraConfig:
 class TextToSemantic(L.LightningModule):
     def __init__(
         self,
-        model: Transformer,
+        model: NaiveTransformer,
         optimizer: Any,
         lr_scheduler: Any,
         lora_config: Optional[LoraConfig] = None,
@@ -184,18 +184,14 @@ class TextToSemantic(L.LightningModule):
             ignore_index=-100,
         )
 
-        # If we have a codebook, add the loss
-        if self.model.config.num_codebooks != 0:
-            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
-            semantic_loss = F.cross_entropy(
-                codebook_logits.reshape(-1, codebook_logits.size(-1)),
-                codebook_labels.reshape(-1),
-                ignore_index=-100,
-            )
+        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
+        semantic_loss = F.cross_entropy(
+            codebook_logits.reshape(-1, codebook_logits.size(-1)),
+            codebook_labels.reshape(-1),
+            ignore_index=-100,
+        )
 
-            loss = base_loss + semantic_loss
-        else:
-            loss = base_loss
+        loss = base_loss + semantic_loss
 
         # If we use dpo
         if self.use_dpo:
@@ -270,39 +266,26 @@ class TextToSemantic(L.LightningModule):
             logger=True,
         )
 
-        if self.model.config.num_codebooks != 0:
-            self.log(
-                f"{stage}/base_loss",
-                base_loss,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-            )
+        self.log(
+            f"{stage}/base_loss",
+            base_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+        )
 
-            self.log(
-                f"{stage}/semantic_loss",
-                semantic_loss,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-            )
+        self.log(
+            f"{stage}/semantic_loss",
+            semantic_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+        )
 
         # Top-5 accuracy
-        if self.model.config.num_codebooks == 0:
-            _, indices = token_logits.topk(5, dim=-1)
-            correct = indices.eq(labels[:, 0].unsqueeze(-1))
-            correct[labels[:, 0] == -100] = 0
-            correct = correct.sum()
-            accuracy = correct / (labels[:, 0] != -100).sum()
-        else:
-            _, indices = codebook_logits.topk(5, dim=-1)
-            correct = indices.eq(codebook_labels.unsqueeze(-1))
-            correct[codebook_labels == -100] = 0
-            correct = correct.sum()
-            accuracy = correct / (codebook_labels != -100).sum()
-
+        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
         self.log(
             f"{stage}/top_5_accuracy",
             accuracy,
@@ -312,8 +295,32 @@ class TextToSemantic(L.LightningModule):
             logger=True,
         )
 
+        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
+            accuracy = self.get_accuracy(
+                codebook_logits[:, :, : self.model.config.num_in_codebooks],
+                codebook_labels[:, :, : self.model.config.num_in_codebooks],
+            )
+
+            self.log(
+                f"{stage}/top_5_accuracy_in",
+                accuracy,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+            )
+
         return loss
 
+    def get_accuracy(self, logits, labels):
+        _, indices = logits.topk(5, dim=-1)
+        correct = indices.eq(labels.unsqueeze(-1))
+        correct[labels == -100] = 0
+        correct = correct.sum()
+        accuracy = correct / (labels != -100).sum()
+
+        return accuracy
+
     def training_step(self, batch, batch_idx):
         return self._step(batch, batch_idx, "train")
 

+ 164 - 88
fish_speech/models/text2semantic/llama.py

@@ -17,10 +17,9 @@ def find_multiple(n: int, k: int) -> int:
 
 
 @dataclass
-class ModelArgs:
+class BaseModelArgs:
     vocab_size: int = 32000
-    n_slow_layer: int = 32
-    n_fast_layer: int = 4
+    n_layer: int = 32
     n_head: int = 32
     dim: int = 4096
     intermediate_size: int = None
@@ -31,9 +30,10 @@ class ModelArgs:
     max_seq_len: int = 2048
     dropout: float = 0.0
 
-    # Additional decoding heads
+    # Codebook configs
     codebook_size: int = 160
     num_codebooks: int = 4
+    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0
 
     # Gradient checkpointing
@@ -46,9 +46,21 @@ class ModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
+        if self.num_in_codebooks is None:
+            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
 
 
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+    pass
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+    n_fast_layer: int = 4
+
+
 class KVCache(nn.Module):
     def __init__(
         self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
@@ -76,42 +88,32 @@ class TransformerForwardResult:
     codebook_logits: Tensor
 
 
-class Transformer(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+@dataclass
+class BaseTransformerForwardResult:
+    logits: Tensor
+    hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+    def __init__(self, config: BaseModelArgs) -> None:
         super().__init__()
         self.config = config
 
         # Slow transformer
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_codebooks,
+            config.vocab_size + config.codebook_size * config.num_in_codebooks,
             config.dim,
         )
-        self.slow_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
+        self.layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
         )
-        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.slow_output = nn.Linear(
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = nn.Linear(
             config.dim,
             config.vocab_size,
             bias=False,
         )
 
-        # Fast transformer
-        self.fast_embeddings = nn.Embedding(
-            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
-        )
-
-        # The equivalent bs is so large that sdpa doesn't work
-        self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
-        )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.fast_output = nn.Linear(
-            config.dim,
-            config.codebook_size,
-            bias=False,
-        )
-
         self.register_buffer(
             "freqs_cis",
             precompute_freqs_cis(
@@ -148,8 +150,7 @@ class Transformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
-        # Slow transformer
-        for b in self.slow_layers:
+        for b in self.layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 max_seq_len,
@@ -158,24 +159,9 @@ class Transformer(nn.Module):
                 dtype=dtype,
             )
 
-        # Fast transformer
-        # The max seq len here is the number of codebooks
-        for b in self.fast_layers:
-            b.attention.kv_cache = KVCache(
-                max_batch_size,
-                self.config.num_codebooks,
-                self.config.n_local_heads,
-                head_dim,
-                dtype=dtype,
-            )
-
     def embed(self, x: Tensor) -> Tensor:
-        # Here we want to merge the embeddings of the codebooks
-        if self.config.num_codebooks == 0:
-            return self.embeddings(x[:, 0])
-
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_codebooks):
+        for i in range(self.config.num_in_codebooks):
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )
@@ -189,10 +175,9 @@ class Transformer(nn.Module):
 
     def forward(
         self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
-    ) -> TransformerForwardResult:
+    ) -> BaseTransformerForwardResult:
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
-        codebooks = inp[:, 1:]
 
         # Here we want to merge the embeddings of the codebooks
         x = self.embed(inp)
@@ -206,15 +191,136 @@ class Transformer(nn.Module):
         if key_padding_mask is not None:
             mask = mask & key_padding_mask[:, None, None, :].logical_not()
 
-        for layer in self.slow_layers:
+        for layer in self.layers:
             if self.config.use_gradient_checkpointing and self.training:
                 x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
             else:
                 x = layer(x, freqs_cis, mask)
 
         # We got slow_out here
-        slow_out = self.slow_norm(x)
-        token_logits = self.slow_output(slow_out)
+        slow_out = self.norm(x)
+        token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> BaseTransformerForwardResult:
+        # This is used for generation, optimized for torch compile
+        assert (
+            self.max_seq_len != -1 and self.max_batch_size != -1
+        ), "Please call setup_caches before forward_generate"
+
+        x = self.embed(x)
+
+        mask = self.causal_mask[
+            None, None, input_pos, : self.max_seq_len
+        ]  # (B, N, Q, K)
+        freqs_cis = self.freqs_cis[input_pos]
+
+        for layer in self.layers:
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # If prefill, we only calculate the logits of last token
+        if x.size(1) > 1:
+            x = x[:, -1:]
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+        token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+
+class NaiveTransformer(BaseTransformer):
+    def __init__(self, config: NaiveModelArgs) -> None:
+        super().__init__(config)
+
+        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.codebook_output = nn.Linear(
+            config.dim,
+            config.codebook_size * config.num_codebooks,
+            bias=False,
+        )
+
+    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+        token_logits = result.logits
+        x = result.hidden_states
+
+        # Codebook
+        codebook_logits = self.codebook_output(self.codebook_norm(x))
+        codebook_logits = rearrange(
+            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
+
+    def forward(
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        result = super().forward(inp, key_padding_mask)
+        return self.decode(result)
+
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        result = super().forward_generate(x, input_pos)
+        return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+    def __init__(self, config: DualARModelArgs) -> None:
+        super().__init__(config)
+
+        # Fast transformer
+        self.fast_embeddings = nn.Embedding(
+            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
+        )
+
+        # The equivalent bs is so large that sdpa doesn't work
+        self.fast_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+        )
+        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_output = nn.Linear(
+            config.dim,
+            config.codebook_size,
+            bias=False,
+        )
+
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+    ):
+        super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+        head_dim = self.config.dim // self.config.n_head
+
+        # Fast transformer
+        # The max seq len here is the number of codebooks
+        for b in self.fast_layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                self.config.num_codebooks,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            )
+
+    def forward(
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        parent_result = super().forward(inp, key_padding_mask)
+        token_logits = parent_result.logits
+        x = parent_result.hidden_states
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
@@ -224,7 +330,7 @@ class Transformer(nn.Module):
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
-        codebooks = codebooks[:, :-1, 1:]
+        codebooks = inp[:, 1:-1, 1:]
         codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
         codebook_embeddings = self.fast_embeddings(codebooks)
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)
@@ -272,36 +378,6 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
-    def forward_generate_slow(
-        self, x: Tensor, input_pos: Optional[Tensor] = None
-    ) -> Tensor:
-        ### TODO: fix this
-        # x: (batch, num_codebooks + 1, 1)
-
-        assert (
-            self.max_seq_len != -1 and self.max_batch_size != -1
-        ), "Please call setup_caches before forward_generate"
-
-        x = self.embed(x)
-
-        mask = self.causal_mask[
-            None, None, input_pos, : self.max_seq_len
-        ]  # (B, N, Q, K)
-        freqs_cis = self.freqs_cis[input_pos]
-
-        for layer in self.slow_layers:
-            x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        # If prefill, we only calculate the logits of last token
-        if x.size(1) > 1:
-            x = x[:, -1:]
-
-        # We got slow_out here
-        slow_out = self.slow_norm(x)
-        token_logits = self.slow_output(slow_out)
-
-        return x, token_logits
-
     def forward_generate_fast(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
@@ -324,7 +400,7 @@ class Transformer(nn.Module):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
         super().__init__()
         self.attention = Attention(config, use_sdpa=use_sdpa)
         self.feed_forward = FeedForward(config)
@@ -340,7 +416,7 @@ class TransformerBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
         super().__init__()
         assert config.dim % config.n_head == 0
 
@@ -443,7 +519,7 @@ class Attention(nn.Module):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: BaseModelArgs) -> None:
         super().__init__()
         self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
         self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
@@ -494,10 +570,10 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
 
 if __name__ == "__main__":
-    args = ModelArgs(
+    args = DualARModelArgs(
         max_seq_len=4096,
         vocab_size=32312,
-        n_slow_layer=12,
+        n_layer=12,
         n_fast_layer=4,
         n_head=12,
         dim=768,
@@ -507,7 +583,7 @@ if __name__ == "__main__":
         num_codebooks=4,
     )
 
-    model = Transformer(args)
+    model = DualARTransformer(args)
     model = model.cuda().bfloat16()
     print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
 
@@ -516,4 +592,4 @@ if __name__ == "__main__":
     key_padding_mask[0, 2:] = True
     x1 = model(inputs, key_padding_mask=key_padding_mask)
     print(x1.token_logits.shape)
-    # print(x1.codebook_logits.shape)
+    print(x1.codebook_logits.shape)
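
The updated __main__ block smoke-tests the dual-AR variant; an assumed analogue for the naive decoder (not part of the commit, run on CPU for brevity) would be:

import torch
from fish_speech.models.text2semantic.llama import NaiveModelArgs, NaiveTransformer

args = NaiveModelArgs(
    max_seq_len=4096,
    vocab_size=32312,
    n_layer=12,
    n_head=12,
    dim=768,
    num_codebooks=4,
)
model = NaiveTransformer(args)

inputs = torch.randint(0, 100, (2, 5, 128))  # (batch, num_codebooks + 1, seq_len)
result = model(inputs)
print(result.token_logits.shape)     # (2, 128, 32312)
print(result.codebook_logits.shape)  # (2, 128, 4, 160)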

+ 39 - 84
tools/llama/generate.py

@@ -1,7 +1,7 @@
 import os
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import click
 import numpy as np
@@ -25,7 +25,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
-from fish_speech.models.text2semantic.llama import Transformer
+from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
 from fish_speech.text import g2p
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
@@ -89,8 +89,8 @@ def sample(
     return idx_next, probs
 
 
-def decode_one_token(
-    model: Transformer,
+def decode_one_token_ar(
+    model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
@@ -115,7 +115,6 @@ def decode_one_token(
     for codebook_idx in range(model.config.num_codebooks):
         input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
         logits = model.forward_generate_fast(x, input_pos)
-        # print(x.shape, logits.shape)
         a = sample(
             logits,
             previous_tokens=(
@@ -127,73 +126,49 @@ def decode_one_token(
         )[0]
         x = model.fast_embeddings(a)
         codebooks.append(a)
-        # x = torch.cat(buffer, dim=1)
-        # logits = model.forward_fast(x)[:, -1:, :]
-        # a = sample(
-        #     logits,
-        #     previous_tokens=(
-        #         previous_tokens[codebook_idx + 1]
-        #         if previous_tokens is not None
-        #         else None
-        #     ),
-        #     **sampling_kwargs,
-        # )[0]
-        # x = model.fast_embeddings(a)
-        # codebooks.append(a)
-        # buffer.append(x.view(1, 1, -1))
 
     return torch.stack(codebooks, dim=0)
 
 
-def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
+def decode_one_token_naive(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
 ) -> torch.Tensor:
-    # input_pos: [B, S]
-    x, logits = model.forward_generate_slow(x, input_pos)
+    assert input_pos.shape[-1] == 1
 
+    logits = model.forward_generate(x, input_pos)
     codebooks = [
         sample(
-            logits,
-            previous_tokens=None,
+            logits.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]
 
-    # Cleanup the cache
-    for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
-
-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
-        # print(x.shape, logits.shape)
-        a = sample(
-            logits,
-            previous_tokens=None,
-            **sampling_kwargs,
-        )[0]
-        x = model.fast_embeddings(a)
-        codebooks.append(a)
-        # x = torch.cat(buffer, dim=1)
-        # logits = model.forward_fast(x)[:, -1:, :]
-        # a = sample(
-        #     logits,
-        #     **sampling_kwargs,
-        # )[0]
-        # x = model.fast_embeddings(a)
-        # codebooks.append(a)
-        # buffer.append(x.view(1, 1, -1))
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample(
+                logits.codebook_logits[:, :, i],
+                previous_tokens=previous_tokens[i + 1]
+                if previous_tokens is not None
+                else None,
+                **sampling_kwargs,
+            )[0]
+        )
 
     return torch.stack(codebooks, dim=0)
 
 
 def decode_n_tokens(
-    model: Transformer,
+    model: NaiveTransformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
     eos_token_id: int = 2,
+    decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -238,10 +213,11 @@ def decode_n_tokens(
 @torch.inference_mode()
 def generate(
     *,
-    model: Transformer,
+    model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
     eos_token_id: int = 2,
+    decode_one_token=decode_one_token_naive,
     precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -273,7 +249,7 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)
 
-    next_token = prefill(
+    next_token = decode_one_token(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
     seq[:, T : T + 1] = next_token
@@ -285,6 +261,7 @@ def generate(
         input_pos,
         max_new_tokens - 1,
         eos_token_id=eos_token_id,
+        decode_one_token=decode_one_token,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -320,20 +297,6 @@ def encode_tokens(
         string = f"[SPK: {speaker}] {string}"
 
     string = f"[INST] {string} [/INST]"
-
-    # Handle English less frequent words
-    # TODO: update tokenizer to handle this
-    # sub_strings = string.split(" ")
-    # new_tokens = []
-    # for i, string in enumerate(sub_strings):
-    #     tokens = tokenizer.encode(
-    #         string,
-    #         add_special_tokens=i == 0 and bos,
-    #         max_length=10**6,
-    #         truncation=False,
-    #     )
-    #     new_tokens.extend(tokens)
-
     new_tokens = tokenizer.encode(
         string,
         add_special_tokens=bos,
@@ -384,7 +347,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
 
-    model: Transformer = instantiate(cfg.model).model
+    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
@@ -523,7 +486,7 @@ def main(
                 num_codebooks=model.config.num_codebooks,
             )
         )
-        print(f"Encoded text: {text}")
+        logger.info(f"Encoded text: {text}")
 
     if use_prompt:
         encoded_prompt = encode_tokens(
@@ -540,14 +503,16 @@ def main(
 
         encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
 
-    # prompt_length = encoded.size(1)
-    # logger.info(f"Encoded prompt shape: {encoded.shape}")
-
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
+    decode_one_token = (
+        decode_one_token_ar
+        if isinstance(model, DualARTransformer)
+        else decode_one_token_naive
+    )
+
     if compile:
-        global decode_one_token
         decode_one_token = torch.compile(
             decode_one_token, mode="reduce-overhead", fullgraph=True
         )
@@ -573,12 +538,11 @@ def main(
             if i != 0 and i % 2 == 0:
                 i -= 1
 
+            # Rotate the list
             if i < len(global_encoded) - 2:
                 partial_encoded = global_encoded[-i:]
-                print(f"Loaded partial encoded")
             else:
                 partial_encoded = global_encoded
-                print(f"Using full encoded")
 
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
@@ -589,6 +553,7 @@ def main(
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
                 eos_token_id=tokenizer.eos_token_id,
+                decode_one_token=decode_one_token,
                 precision=precision,
                 temperature=temperature,
                 top_k=top_k,
@@ -617,15 +582,6 @@ def main(
             # Put the generated tokens
             codes = y[1:, prompt_length:-1].clone()
 
-            # if getattr(cfg, "use_delay_pattern", True):
-            #     new_codes = []
-            #     for j, code in enumerate(codes):
-            #         new_codes.append(
-            #             code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
-            #         )
-
-            #     codes = torch.stack(new_codes, dim=0)
-
             codes = codes - 2
             if not (codes >= 0).all():
                 global_encoded.pop()
@@ -638,7 +594,6 @@ def main(
 
         codes = torch.cat(all_codes, dim=1)
         assert (codes >= 0).all(), f"Negative code found: {codes}"
-        print(codes)
 
         np.save(f"codes_{idx}.npy", codes.cpu().numpy())
         logger.info(f"Saved codes to codes_{idx}.npy")