فهرست منبع

Fix audio quality bug (#1014)

* Update inference.py

* Update llama.py

* Update content_sequence.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Stardust·减 10 ماه پیش
والد
کامیت
83afe683e4
3 فایل‌های تغییر یافته به همراه 251 افزوده شده و 200 حذف شده
  1. 8 3
      fish_speech/content_sequence.py
  2. 162 153
      fish_speech/models/text2semantic/inference.py
  3. 81 44
      fish_speech/models/text2semantic/llama.py

+ 8 - 3
fish_speech/content_sequence.py

@@ -271,7 +271,7 @@ class ContentSequence:
         self: "ContentSequence",
         tokenizer: FishTokenizer,
         num_codebooks: int,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         encoded = self.encode(tokenizer, add_shift=False)
         tokens = encoded.tokens
         values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
@@ -280,8 +280,9 @@ class ContentSequence:
         if (encoded.vq_parts is None or len(encoded.vq_parts) == 0) and (
             encoded.audio_parts is None or len(encoded.audio_parts) == 0
         ):
-            return values
+            return values, None, None
 
+        audio_parts = audio_masks = None
         if encoded.vq_parts is not None and len(encoded.vq_parts) > 0:
             vq_parts = encoded.vq_parts
             vq_parts = torch.cat(vq_parts, dim=1)
@@ -290,7 +291,11 @@ class ContentSequence:
             )
             values[1:, encoded.vq_mask_tokens] = vq_parts
 
-        return values
+        if encoded.audio_parts is not None and len(encoded.audio_parts) > 0:
+            audio_parts = torch.cat(encoded.audio_parts, dim=0)
+            audio_masks = encoded.audio_masks[None, :]
+
+        return values, audio_masks, audio_parts
 
     def visualize(
         self: "ContentSequence",

+ 162 - 153
fish_speech/models/text2semantic/inference.py

@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+import traceback
 from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
@@ -35,6 +36,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
@@ -49,19 +51,19 @@ def multinomial_sample_one_no_sync(
 
 def logits_to_probs(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
 ) -> torch.Tensor:
     # Apply repetition penalty
     if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
 
     # Apply top-p sampling
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -69,11 +71,10 @@ def logits_to_probs(
     sorted_indices_to_remove = cum_probs > top_p
     sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
     )
     logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
+    logits = logits / torch.clip(temperature, min=1e-5)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
@@ -81,11 +82,17 @@ def logits_to_probs(
 
 def sample(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     probs = logits_to_probs(
-        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+        logits=logits[0, -1],
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        previous_tokens=previous_tokens,
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -95,40 +102,35 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
 ) -> torch.Tensor:
-    """
-    Generate one token using dual autoregressive transformer for text-to-speech.
-
-    First generates semantic tokens, then generates acoustic codebook tokens sequentially.
-
-    Args:
-        x: Input token tensor (1, num_codebooks+1, seq_len)
-        input_pos: Position indices for input tokens (seq_len,)
-        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
-        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
-        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
-
-    Returns:
-        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
-    """
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
+    # print(x, torch.count_nonzero(vq_masks))
+    x = model.forward_generate(
+        x,
+        input_pos,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
+    )
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]
 
     codebooks = [
         sample(
-            x.logits,
+            logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
-                previous_tokens[0] if previous_tokens is not None else None
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
+                previous_tokens[:, 0] if previous_tokens is not None else None
+            ),
         )[0]
     ]
 
-    hidden_states = x.hidden_states
-
     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
@@ -146,22 +148,27 @@ def decode_one_token_ar(
             [codebook_idx], device=hidden_states.device, dtype=torch.long
         )
         logits = model.forward_generate_fast(hidden_states, input_pos)
-        chunked_logits = logits[..., :1024]
+
+        short_logits = logits[:, :, :1024]
+
+        # Convert logits to probs
         a = sample(
-            chunked_logits,
+            short_logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
                 previous_tokens[codebook_idx + 1]
                 if previous_tokens is not None
                 else None
             ),
-            **sampling_kwargs,
         )[0]
+
         hidden_states = model.fast_embeddings(a)
         codebooks.append(a)
 
-    codebooks = torch.stack(codebooks, dim=0)
-
-    return codebooks
+    codebooks = torch.stack(codebooks, dim=1)
+    return codebooks.T
 
 
 def decode_n_tokens(
@@ -169,24 +176,13 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
-    **sampling_kwargs,
 ):
-    """
-    Generate n tokens iteratively using the model.
-
-    Args:
-        model: The transformer model
-        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
-        input_pos: Current input position tensor
-        num_new_tokens: Number of new tokens to generate
-        semantic_ids: List of semantic token IDs
-        decode_one_token: Function to decode one token
-        **sampling_kwargs: Additional sampling parameters
-
-    Returns:
-        Generated tokens tensor of shape (num_codebooks+1, generated_len)
-    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
@@ -201,13 +197,19 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with sdpa_kernel(SDPBackend.MATH):
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                **sampling_kwargs,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                audio_masks=audio_masks,
+                audio_parts=audio_parts,
             ).clone()
 
         input_pos += 1
@@ -226,33 +228,31 @@ def decode_n_tokens(
 @torch.inference_mode()
 def generate(
     *,
-    model: NaiveTransformer,
+    model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
+    num_samples: int = 1,
     **sampling_kwargs,
-) -> torch.Tensor:
+):
     """
-    Generate tokens from text prompt using the transformer model.
-
-    Args:
-        model: The transformer model for generation
-        prompt: Input token tensor of shape (num_codebooks+1, seq_len)
-        max_new_tokens: Maximum number of new tokens to generate
-        decode_one_token: Function to decode one token at a time
-        **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
-
-    Returns:
-        Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
-        where total_seq_len = original_seq_len + generated_tokens_len
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
 
+    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
             max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
 
         T_new = T + max_new_tokens
     else:
@@ -260,23 +260,40 @@ def generate(
         max_new_tokens = T_new - T
 
     device, dtype = prompt.device, prompt.dtype
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=num_samples,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
 
     codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
     empty[:, :T] = prompt
     seq = empty
-    input_pos = torch.arange(0, T, device=device)
 
-    # Use non-accelerated version for now, to avoid compilation overhead
+    temperature = torch.tensor(
+        sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
+    )
+    top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
+    repetition_penalty = torch.tensor(
+        sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
+    )
+
     prefill_decode = decode_one_token_ar
 
     first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        **sampling_kwargs,
+        temperature,
+        top_p,
+        repetition_penalty,
+        audio_masks,
+        audio_parts,
     )
     seq[:, T : T + 1] = first_token
 
@@ -286,12 +303,15 @@ def generate(
         first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
         decode_one_token=decode_one_token,
-        **sampling_kwargs,
     )
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
-
     return seq
 
 
@@ -303,17 +323,26 @@ def init_model(checkpoint_path, device, precision, compile=False):
 
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
+        prefill_n_tokens = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
-        raise ValueError("Model is not a DualARTransformer")
+        raise ValueError("Unsupported model type")
+
+    # Initialize cache
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
 
     if compile:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
+            # mode="max-autotune-no-cudagraphs",
+            mode="reduce-overhead",
             fullgraph=True,
-            backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="reduce-overhead" if torch.cuda.is_available() else None,
         )
 
     return model.eval(), decode_one_token
@@ -362,9 +391,7 @@ def generate_long(
     tokenizer = model.tokenizer
     base_content_sequence = ContentSequence(modality="interleave")
 
-    texts = split_text(text, chunk_length) if iterative_prompt else [text]
     max_length = model.config.max_seq_len
-
     if use_prompt:
         for t, c in zip(prompt_text, prompt_tokens):
             base_content_sequence.append(
@@ -373,26 +400,24 @@ def generate_long(
                     VQPart(codes=c),
                 ],
                 add_end=True,
+                speaker=0,
             )
+    base_content_sequence.append(
+        [
+            TextPart(text=text),
+        ],
+        add_end=False,
+        speaker=0,
+    )
 
-    encoded_prompts = base_content_sequence.encode_for_inference(
+    encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
         tokenizer, num_codebooks=model.config.num_codebooks
     )
-    if encoded_prompts.size(1) > max_length - 2048:
-        raise ValueError(
-            f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
-        )
+    if encoded.size(1) > max_length - 2048:
+        raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
 
-    encoded = []
-    for text in texts:
-        content_sequence = ContentSequence(modality="text")
-        content_sequence.append(TextPart(text=text))
-        encoded.append(
-            content_sequence.encode_for_inference(
-                tokenizer, num_codebooks=model.config.num_codebooks
-            )
-        )
-        logger.info(f"Encoded text: {text}")
+    encoded = encoded.to(device=device)
+    logger.info(f"Encoded text: {text}")
 
     # Move temperature, top_p, repetition_penalty to device
     # This is important so that changing params doesn't trigger recompile
@@ -408,70 +433,53 @@ def generate_long(
 
         global_encoded = []
         seg_idx = 0
+        prompt_length = encoded.size(1)
+
+        t0 = time.perf_counter()
+        y = generate(
+            model=model,
+            prompt=encoded,
+            max_new_tokens=max_new_tokens,
+            audio_masks=audio_masks,
+            audio_parts=audio_parts,
+            decode_one_token=decode_one_token,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+        )
 
-        while seg_idx < len(encoded):
-            logger.info(
-                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
-            )
-
-            seg = encoded[seg_idx]
-            global_encoded.append(seg)
-
-            if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
-                cat_encoded = torch.cat(
-                    [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
-                )
-            else:
-                cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
-
-            cat_encoded = cat_encoded.to(device=device)
-            prompt_length = cat_encoded.size(1)
-
-            t0 = time.perf_counter()
-            y = generate(
-                model=model,
-                prompt=cat_encoded,
-                max_new_tokens=max_new_tokens,
-                decode_one_token=decode_one_token,
-                temperature=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-            )
+        if sample_idx == 0 and seg_idx == 0 and compile:
+            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
-            if sample_idx == 0 and seg_idx == 0 and compile:
-                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
 
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
+        t = time.perf_counter() - t0
 
-            t = time.perf_counter() - t0
+        tokens_generated = y.size(1) - prompt_length
+        tokens_sec = tokens_generated / t
+        logger.info(
+            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+        )
+        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
 
-            tokens_generated = y.size(1) - prompt_length
-            tokens_sec = tokens_generated / t
-            logger.info(
-                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-            )
+        if torch.cuda.is_available():
             logger.info(
-                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
             )
 
-            if torch.cuda.is_available():
-                logger.info(
-                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-                )
-
-            # Put the generated tokens
-            # since there is <im_end>, we remove last token
-            codes = y[1:, prompt_length:-1].clone()
-            assert (codes >= 0).all(), f"Negative code found"
+        # Put the generated tokens
+        # since there is <im_end>, we remove last token
+        codes = y[1:, prompt_length:-1].clone()
+        assert (codes >= 0).all(), f"Negative code found"
 
-            decoded = y[:, prompt_length:].clone()
-            # But for global encoding, we should keep the <im_end> token
+        decoded = y[:, prompt_length:].clone()
+        # But for global encoding, we should keep the <im_end> token
 
-            global_encoded.append(decoded.cpu())
-            assert (codes >= 0).all(), f"Negative code found: {codes}"
-            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
-            seg_idx += 1
+        global_encoded.append(decoded.cpu())
+        assert (codes >= 0).all(), f"Negative code found: {codes}"
+        yield GenerateResponse(action="sample", codes=codes, text=text)
+        seg_idx += 1
 
         # This indicates the end of the current sample
         yield GenerateResponse(action="next")
@@ -526,6 +534,7 @@ def launch_thread_safe_queue(
                         WrappedGenerateResponse(status="success", response=chunk)
                     )
             except Exception as e:
+                logger.error(traceback.format_exc())
                 response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
     threading.Thread(target=worker, daemon=True).start()

+ 81 - 44
fish_speech/models/text2semantic/llama.py

@@ -320,9 +320,45 @@ class BaseTransformer(nn.Module):
         self,
         inp: Tensor,
         input_pos: Optional[Tensor] = None,
+        audio_masks: Optional[Tensor] = None,
+        audio_parts: Optional[Tensor] = None,
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
-        x = self.embed(inp)
+        # This is used for generation, optimized for torch compile
+        # assert (
+        #     self.max_seq_len != -1 and self.max_batch_size != -1
+        # ), "Please call setup_caches before forward_generate"
+
+        embeds = []
+        for i in range(self.config.num_codebooks):
+            emb = self.codebook_embeddings(
+                inp[:, i + 1] + i * self.config.codebook_size
+            )
+            embeds.append(emb)
+
+        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+
+        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+            inp[:, 0] <= self.tokenizer.semantic_end_id
+        )
+
+        vq_embeds_sum[~vq_masks] = 0
+        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+        if self.config.scale_codebook_embeddings:
+            # Expand vq_masks to match x's shape
+            vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
+            x = torch.where(
+                vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
+            )
+
+        # Audio embeddings
+        if audio_parts is not None:
+            audio_embeds = self.audio_projector(audio_parts)
+            if self.config.scale_codebook_embeddings:
+                x[audio_masks] = audio_embeds / math.sqrt(2)
+            else:
+                x[audio_masks] = audio_embeds
 
         if input_pos is None:
             input_pos = torch.arange(inp.shape[-1], device=x.device)
@@ -595,69 +631,69 @@ class DualARTransformer(BaseTransformer):
     def forward(
         self,
         inp: Tensor,
+        labels: Optional[Tensor] = None,
         key_padding_mask: Optional[Tensor] = None,
+        vq_parts: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,
+        vq_require_losses: Optional[Tensor] = None,
+        mel_parts: Optional[Tensor] = None,
+        mel_masks: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        parent_result = super().forward(inp, key_padding_mask)
+        parent_result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+            vq_parts=vq_parts,
+            vq_masks=vq_masks,
+            mel_parts=mel_parts,
+            mel_masks=mel_masks,
+        )
         token_logits = parent_result.logits
         x = parent_result.hidden_states
-        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
+        fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]
+
+        # Extract corresponding parts with labels
+        codebook_mask = labels == self.semantic_token_id
+        # This gives where input token is <|semantic|>
+        x = x[codebook_mask]
+
+        if x.shape[0] == 0:
+            # Use dummy input when no vq is required
+            x = torch.zeros(
+                (4, self.config.dim),
+                device=x.device,
+                dtype=x.dtype,
+            )
+            codebooks = torch.zeros(
+                (x.shape[0], self.config.num_codebooks - 1),
+                device=x.device,
+                dtype=torch.int,
+            )
+        else:
+            codebooks = vq_parts[..., :-1][vq_require_losses][
+                vq_masks[vq_require_losses]
+            ]
 
-        # Drop the last token and rotate left
-        codebooks = inp[:, 1:-1, 1:]
-        codebooks = F.pad(codebooks, (0, 1), value=0)
+        x = self.fast_project_in(x)
         codebook_embeddings = self.fast_embeddings(codebooks)
         x = torch.cat([x[:, None], codebook_embeddings], dim=1)
-        b, s = x.size(0), x.size(2)
-        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
-
-        # Remove padded part
-        codebooks = rearrange(codebooks, "b n s -> (b s) n")
-        codebook_mask = (codebooks == 0).all(dim=-1)
-
-        if torch.all(codebook_mask):
-            # If all codebooks are padded, we keep first 8 to make sure the model runs
-            codebook_mask[:8] = False
-
-        x_bs, x_len = x.size(0), x.size(1)
-        x = x[~codebook_mask]
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(
-                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
-                )
+                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
             else:
-                x = layer(x, self.fast_freqs_cis, fast_mask)
+                x = layer(x, fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
         codebook_logits = self.fast_output(fast_out)
 
-        # Re-pad the codebook_logits
-        buffer = torch.zeros(
-            x_bs,
-            x_len,
-            codebook_logits.size(-1),
-            device=codebook_logits.device,
-            dtype=codebook_logits.dtype,
-        )
-        buffer[~codebook_mask] = codebook_logits
-        codebook_logits = buffer
-
         assert codebook_logits.shape[1] == self.config.num_codebooks
-        codebook_logits = rearrange(
-            codebook_logits,
-            "(b s) n d -> b s n d",
-            b=b,
-            s=s,
-            n=self.config.num_codebooks,
-        )
 
         return TransformerForwardResult(
             token_logits=token_logits,
@@ -668,7 +704,7 @@ class DualARTransformer(BaseTransformer):
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
         # Fast transformer
-        x = x.view(1, 1, -1)
+        x = x.view(x.shape[0], 1, -1)
 
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
@@ -688,9 +724,10 @@ class DualARTransformer(BaseTransformer):
         self,
         x: Tensor,
         input_pos: Optional[Tensor] = None,
-        vq_masks: Optional[Tensor] = None,
+        audio_masks: Optional[Tensor] = None,
+        audio_parts: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        x = super().forward_generate(x, input_pos, vq_masks)
+        x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
         x.hidden_states = self.fast_project_in(x.hidden_states)
         return x