Просмотр исходного кода

Add llama inference tool chain

Lengyue 2 года назад
Родитель
Commit
62710e34a4

+ 11 - 0
LICENSE

@@ -0,0 +1,11 @@
+Copyright 2023 Lengyue
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 2 - 0
README.md

@@ -21,3 +21,5 @@ pip3 install -e .
 - [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
 - [GPT VITS](https://github.com/innnky/gpt-vits)
 - [MQTTS](https://github.com/b04901014/MQTTS)
+- [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+

+ 4 - 4
fish_speech/configs/text2semantic.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_100m
+project: text2semantic_400m
 
 # Lightning Trainer
 trainer:
@@ -50,9 +50,9 @@ model:
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
       max_seq_len: 4096
       vocab_size: 32312
-      n_layer: 12
-      n_head: 12
-      dim: 768
+      n_layer: 24
+      n_head: 16
+      dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
       codebook_size: 168

+ 12 - 5
fish_speech/datasets/text.py

@@ -201,15 +201,19 @@ class AutoAugTextDataset(IterableDataset):
             )
 
         tokens = self.tokenizer.encode(
-            f"{sentence}", max_length=10**6, add_special_tokens=False
+            f"{sentence}",
+            max_length=10**6,
+            add_special_tokens=False,
+            truncation=False,
         )
         return sentence, len(tokens)
 
     def augment(self, group):
         # 50% to pure text or pure phones
-        mode = "sample"
-        if random.random() < 0.5:
-            mode = random.choice(["text", "phones"])
+        # mode = "sample"
+        # if random.random() < 0.5:
+        #     mode = random.choice(["text", "phones"])
+        mode = "phones"
 
         # Random sample based on speaker using a truncated normal distribution
         a = torch.tensor([0], dtype=torch.float32)
@@ -243,7 +247,10 @@ class AutoAugTextDataset(IterableDataset):
 
         final_text = "[INST] " + "<pad>".join(final_text) + " [/INST]"
         encoded = self.tokenizer.encode(
-            final_text, max_length=self.max_length, add_special_tokens=False
+            final_text,
+            max_length=self.max_length,
+            add_special_tokens=False,
+            truncation=False,
         )
         semantic_length = sum([len(i[0].values) for i in final_semantic])
 

+ 407 - 0
fish_speech/models/text2semantic/generate.py

@@ -0,0 +1,407 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import itertools
+import sys
+import time
+from pathlib import Path
+from typing import Optional, Tuple
+
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from transformers import AutoTokenizer
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+
+from fish_speech.models.text2semantic.llama import ModelArgs, Transformer
+from fish_speech.models.text2semantic.tp import maybe_init_dist
+from fish_speech.text import g2p
+from fish_speech.text.symbols import pad as pad_symbol
+from fish_speech.text.symbols import pu_symbols
+
+
def multinomial_sample_one_no_sync(probs_sort):
    """Draw a single multinomial sample from `probs_sort`.

    Uses the exponential-noise (Gumbel-style) trick instead of
    `torch.multinomial`, which avoids forcing a CUDA synchronization.
    Returns an int tensor with the sampled index (keepdim on the last axis).
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    winner = torch.argmax(probs_sort / noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
+
+
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn raw logits into a probability distribution.

    Applies temperature scaling (the divisor is floored at 1e-5 so
    temperature=0 does not divide by zero), then optional top-k filtering
    (everything below the k-th largest logit is set to -inf) before softmax.
    """
    scaled = logits / max(temperature, 1e-5)

    if top_k is not None:
        kth_values, _ = torch.topk(scaled, min(top_k, scaled.size(-1)))
        threshold = kth_values.select(-1, -1).unsqueeze(-1)
        scaled = torch.where(scaled < threshold, -float("Inf"), scaled)

    return torch.nn.functional.softmax(scaled, dim=-1)
+
+
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample one token id from the logits of the last position of batch 0.

    Returns (sampled_index, probability_distribution).
    """
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
+
+
def decode_token(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    """Run one forward step and sample the semantic token plus every codebook.

    Returns a stacked tensor: row 0 is the sampled text/semantic token, rows
    1..num_codebooks are the sampled codebook tokens.
    """
    # input_pos: [B, S]
    result = model.forward_generate(x, input_pos)

    # Semantic (text) token first.
    sampled = [sample(result.token_logits, **sampling_kwargs)[0]]

    # Mask out the first two ids (<s> and </s>) for codebooks 2..n.
    result.codebook_logits[:, :, 1:, :2] = -float("Inf")
    for cb in range(model.config.num_codebooks):
        sampled.append(sample(result.codebook_logits[:, :, cb], **sampling_kwargs)[0])

    return torch.stack(sampled, dim=0)
+
+
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    """Autoregressively decode up to `num_new_tokens` tokens.

    Stops early when any codebook row emits token id 1 (treated as EOS).
    Mutates `input_pos` in place and returns the list of generated tokens.
    """
    generated = []
    for _ in range(num_new_tokens):
        # Actually better for Inductor to codegen attention here
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            next_token = decode_token(model, cur_token, input_pos, **sampling_kwargs)

        input_pos += 1
        generated.append(next_token.clone())
        callback(generated[-1])
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)

        # TODO: use tokenizer
        if (cur_token[0, 1:, 0] == 1).any():
            print("EOS detected, stopping generation")
            break

    return generated
+
+
def model_forward(model, x, input_pos):
    """Invoke `model` on `x` at `input_pos` (plain call wrapper)."""
    result = model(x, input_pos)
    return result
+
+
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: the text2semantic Transformer (caches are (re)initialized here).
        prompt: (num_codebooks + 1, T) int tensor of prompt tokens.
        max_new_tokens: upper bound on the number of tokens to generate.
        interactive: caps the cache length at 350 positions for REPL use.
        callback: invoked with each newly generated token.

    Returns:
        (num_codebooks + 1, T + generated) tensor; may be shorter than
        T + max_new_tokens when EOS is hit early.
    """
    T = prompt.size(1)
    T_new = T + max_new_tokens
    if interactive:
        max_seq_length = 350
    else:
        max_seq_length = min(T_new, model.config.max_seq_len)

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_len=max_seq_length)

    codebook_dim = 1 + model.config.num_codebooks
    # Create an empty tensor of the expected final shape and fill in the
    # current tokens; it is trimmed to the actual generated length at the end.
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Prefill: process the whole prompt at once and sample the first token.
    next_token = decode_token(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    # Incremental decode for the remaining tokens, one position at a time.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )

    # decode_n_tokens may return an empty list (max_new_tokens == 1, or EOS on
    # the very first step); torch.cat on an empty list raises, so guard it.
    if generated_tokens:
        x = torch.cat(generated_tokens, dim=1)
        seq = seq[:, : T + 1 + x.size(1)]
        seq[:, T + 1 :] = x
    else:
        seq = seq[:, : T + 1]

    return seq
+
+
def encode_tokens(tokenizer, string, bos=True, device="cuda", num_codebooks=4):
    """Tokenize `string` and stack empty codebook rows under the text tokens.

    Args:
        tokenizer: any object with an HF-style ``encode`` method.
        string: text to encode.
        bos: forwarded as ``add_special_tokens`` (whether to add BOS/EOS).
        device: device for the resulting tensor.
        num_codebooks: number of zero-filled codebook rows to append
            (default 4, matching the previous hard-coded value).

    Returns:
        (num_codebooks + 1, seq_len) int tensor: row 0 holds token ids,
        the remaining rows are zeros (placeholders for semantic codebooks).
    """
    tokens = tokenizer.encode(
        string, max_length=10**6, add_special_tokens=bos, truncation=False
    )
    tokens = torch.tensor([tokens], dtype=torch.int, device=device)

    # Codebook placeholder rows (all zeros).
    zeros = torch.zeros(
        (num_codebooks, tokens.size(1)), dtype=torch.int, device=device
    )
    return torch.cat((tokens, zeros), dim=0)
+
+
def _load_model(checkpoint_path, device, precision, use_tp):
    """Build the Transformer, load checkpoint weights, and prepare for inference.

    Quantization mode is inferred from the checkpoint filename ("int8"/"int4"
    substrings); tensor parallelism is applied when `use_tp` is true.
    Returns the model in eval mode on `device` with dtype `precision`.

    NOTE(review): the architecture below is hard-coded and must match the
    400m config in text2semantic.yaml — see the TODO.
    """
    # Instantiate on the meta device: no real memory is allocated until the
    # checkpoint is loaded with assign=True below.
    with torch.device("meta"):
        # TODO: support different model archs
        model = Transformer(
            ModelArgs(
                max_seq_len=4096,
                vocab_size=32312,
                n_layer=24,
                n_head=16,
                dim=1024,
                rope_base=10000,
                norm_eps=1e-5,
                codebook_size=168,
                num_codebooks=4,
            )
        )

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 quantization!")
        # Group size is encoded in the filename, e.g. "model.g32.int4.pth".
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    # mmap avoids reading the whole file eagerly; assign=True materializes the
    # meta-device parameters directly from the checkpoint tensors.
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp

        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+
+
def main(
    prompt: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path(
        "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
    ),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    tokenizer: str = "fishaudio/speech-lm-v1",
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Converts `prompt` to phones via g2p, runs autoregressive generation
    `num_samples` times, and (in non-interactive mode) dumps the generated
    codebooks to codes_{i}.npy. Requires CUDA.

    NOTE(review): `compile_prefill` is accepted but never used in this body.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    # `print` is rebound to a no-op on non-zero ranks under tensor parallelism.
    global print
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        torch.cuda.set_device(rank)
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None

    device = "cuda"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

    torch.cuda.synchronize()
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    # `tokenizer` starts as a hub name and is rebound to the loaded instance.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    # g2p yields (grapheme, phone) pairs; non-punctuation phones are wrapped
    # as <p:...> tokens, punctuation and pad pass through unchanged.
    prompt = g2p(prompt)
    prompt = [
        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
        for _, i in prompt
    ]
    prompt = " ".join(prompt)
    print(prompt)
    encoded = encode_tokens(
        tokenizer, f"[INST] {prompt} [/INST]", bos=True, device=device
    )
    print(encoded[0])
    prompt_length = encoded.size(1)

    torch.manual_seed(1234)
    # Total parameter+buffer bytes, used for the bandwidth estimate below.
    model_size = sum(
        [
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ]
    )
    if compile:
        # Rebind the module-level decode_token with a compiled version so
        # decode_n_tokens picks it up.
        global decode_token
        decode_token = torch.compile(
            decode_token, mode="reduce-overhead", fullgraph=True
        )

    aggregate_metrics = {
        "tokens_per_sec": [],
    }
    # When compiling, iteration -1 is a warm-up run excluded from metrics.
    start = -1 if compile else 0

    for i in range(start, num_samples):
        torch.cuda.synchronize()
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

        if interactive and i >= 0:
            # Streaming decode: buffer 4 tokens at a time before printing.
            # The [period_id] prefix + [1:] strip works around tokenizers
            # that render a leading space differently at sequence start.
            buffer = []
            period_id = tokenizer.encode(".")[0]
            done_generating = False

            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
                # NOTE(review): HF tokenizers expose `eos_token_id`, not an
                # `eos_id()` method, and `x` here is a stacked
                # (num_codebooks + 1, 1) tensor, so `x.item()` would raise for
                # multi-codebook models — confirm this interactive path.
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print("".join(buffer), end="", flush=True)
                    buffer.clear()
                # print(, end='', flush=True)

        else:
            callback = lambda x: x
        t0 = time.perf_counter()
        import contextlib

        # Profile only the last sample, and only on rank 0 under TP.
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        torch.cuda.synchronize()
        t = time.perf_counter() - t0

        if not interactive:
            print(tokenizer.decode(y[0].tolist()))
            # Rows 1..n are codebooks; "- 2" presumably removes the special
            # token offset — TODO confirm against the training-side encoding.
            codes = y[1:, prompt_length:-1] - 2
            assert (codes >= 0).all()
            import numpy as np

            np.save(f"codes_{i}.npy", codes.cpu().numpy())
        else:
            print()
        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    print("==========")

    print(
        f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
    )
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Your CLI description.")

    parser.add_argument(
        "--prompt",
        type=str,
        default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
        help="Input prompt.",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Whether to launch in interactive mode",
    )
    parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
    parser.add_argument(
        "--max_new_tokens", type=int, default=768, help="Maximum number of new tokens."
    )
    parser.add_argument("--top_k", type=int, default=10, help="Top-k for sampling.")
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Temperature for sampling."
    )
    parser.add_argument(
        "--checkpoint_path",
        type=Path,
        default=Path("results/text2semantic_400m/step_000025000_weights.ckpt"),
        help="Model checkpoint path.",
    )
    parser.add_argument(
        "--compile", action="store_true", help="Whether to compile the model."
    )
    parser.add_argument("--profile", type=Path, default=None, help="Profile path.")

    args = parser.parse_args()
    # Pass by keyword: main() declares `compile_prefill` between `compile` and
    # `profile`, so the previous positional call silently bound args.profile to
    # compile_prefill and left profile at its default None.
    main(
        prompt=args.prompt,
        interactive=args.interactive,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        top_k=args.top_k,
        temperature=args.temperature,
        checkpoint_path=args.checkpoint_path,
        compile=args.compile,
        profile=args.profile,
    )

+ 40 - 0
fish_speech/models/text2semantic/llama.py

@@ -164,6 +164,46 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> "TransformerForwardResult":
        """Single-step forward pass for incremental (KV-cached) generation.

        Args:
            x: (batch, num_codebooks + 1, seq) token ids — row 0 is the text
               token, the rest are codebook tokens.
            input_pos: positions of `x` within the cached sequence.

        Returns:
            TransformerForwardResult with token_logits and codebook_logits
            (the original `-> Tensor` annotation was incorrect).
        """
        # x: (batch, num_codebooks + 1, 1)

        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        # Here we want to merge the embeddings of the codebooks:
        # each codebook occupies its own id range after vocab_size.
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            vocab_embeds.append(emb)

        # Average the text and codebook embeddings into one input vector.
        x = torch.stack(vocab_embeds, dim=3)
        x = x.mean(dim=3)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        # The output head is fused: the first vocab_size columns are text
        # logits, the remainder is split evenly across the codebooks.
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]
        codebook_logits = logits[:, :, self.config.vocab_size :]

        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
+
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:

+ 315 - 480
fish_speech/models/text2semantic/modules.py

@@ -5,206 +5,145 @@ import torch
 from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
-try:
-    from xformers.ops import memory_efficient_attention
-except ImportError as e:
-    memory_efficient_attention = None
 
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute rotary-embedding frequencies as complex exponentials (cis).

    Builds dim // 2 inverse frequencies from a geometric progression with base
    `theta`, takes their outer product with positions 0..end-1, and converts
    each angle to a unit complex number.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape (end, dim // 2).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Unit-magnitude complex numbers e^{i*angle} (complex64).
    return torch.polar(torch.ones_like(angles), angles)
 
 
-class KVCache(nn.Module):
-    def __init__(
-        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
-    ):
-        super().__init__()
-        cache_shape = (max_batch_size, max_seq_length, n_heads * head_dim)
-        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
-        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape the (seq, head_dim) frequency tensor so it broadcasts over `x`.

    Dimension 1 (sequence) and the last dimension (head dim) of `x` are kept;
    every other dimension is collapsed to size 1.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: view of `freqs_cis` shaped for element-wise broadcasting.

    Raises:
        AssertionError: if `freqs_cis` does not match (x.shape[1], x.shape[-1])
            or `x` has fewer than two dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target_shape = [
        size if axis in (1, ndim - 1) else 1 for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*target_shape)
 
-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, input_pos] = k_val
-        v_out[:, input_pos] = v_val
 
-        return k_out, v_out
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embedding to `x`.

    Pairs of the last dimension are viewed as complex numbers, rotated by
    `freqs_cis`, and converted back to the real dtype of `x`.
    (The original annotation claimed a tuple return, but a single tensor
    with the same shape as `x` is returned.)
    """
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, x_)
    return torch.view_as_real(x_ * freqs_cis).flatten(3).type_as(x)
 
 
 class MultiheadAttention(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1):
+    def __init__(self, d_model, nhead, dropout=0.1, is_cross_attention=False):
         super().__init__()
         assert d_model % nhead == 0
         self.nhead = nhead
         self.d_model = d_model
         self.head_dim = d_model // nhead
+        self.is_cross_attention = is_cross_attention
 
-        self.q_proj = nn.Linear(d_model, d_model)
-        self.k_proj = nn.Linear(d_model, d_model)
-        self.v_proj = nn.Linear(d_model, d_model)
-        self.out_proj = nn.Linear(d_model, d_model)
+        # Auto fuse linear projection
+        if is_cross_attention:
+            self.q_proj = nn.Linear(d_model, d_model)
+            self.kv_proj = nn.Linear(d_model, d_model * 2)
+        else:
+            self.qkv_proj = nn.Linear(d_model, d_model * 3)
+
+        self.o_proj = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
-        self.kv_cache = None
 
     def forward(
         self,
         q,
-        k,
-        v,
+        freqs_cis_q,
+        kv=None,
+        freqs_cis_kv=None,
         attn_mask=None,
-        key_padding_mask=None,
-        attn_bias=None,
-        return_weights=False,
         input_pos=None,
+        kv_cache=None,
     ):
-        # (B, T, C)
-        batch_size = q.size(0)
-        q_length = q.size(1)
-
-        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
-
-        if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
-
-        k_length = k.size(1)
-
-        if attn_bias is not None:
-            assert attn_bias.size() == (
-                self.nhead,
-                q_length,
-                k_length,
-            ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
-
-            attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
-
-        if attn_mask is not None:
-            assert attn_mask.size() == (
-                q_length,
-                k_length,
-            ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
-            assert attn_mask.dtype == torch.bool
-            attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.size() == (
-                batch_size,
-                k_length,
-            ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
-            assert key_padding_mask.dtype == torch.bool
-            key_padding_mask = (
-                key_padding_mask.unsqueeze(1)
-                .unsqueeze(1)
-                .expand(-1, self.nhead, -1, -1)
-            )
-            key_padding_mask = key_padding_mask.reshape(
-                batch_size * self.nhead, 1, k_length
-            )
-            if attn_mask is None:
-                attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
+        if self.is_cross_attention:
+            q = self.q_proj(q)
+            if kv is None:
+                assert self.kv_cache is not None, "kv_cache should be initialized"
+                k, v = None
             else:
-                attn_mask = attn_mask.logical_or(key_padding_mask)
-
-        if (
-            return_weights is False
-            and memory_efficient_attention is not None
-            and q.device.type == "cuda"
-        ):
-            # (-> b, t,. n, d)
-            q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
-            k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
-            v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
-
-            if attn_mask is not None:
-                attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
-
-                if attn_bias is None:
-                    attn_bias = torch.zeros_like(
-                        attn_mask, dtype=q.dtype, device=q.device
-                    )
-                attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
-
-            if attn_bias is not None:
-                attn_bias = attn_bias.to(q.dtype)
-
-            attn_output = memory_efficient_attention(
-                q,
-                k,
-                v,
-                attn_bias=attn_bias,
-                scale=self.head_dim**-0.5,
-                p=self.dropout.p,
-            )
-            attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
-
-            returned_weights = None
+                # Using kv cache
+                kv = self.kv_proj(kv)
+                k, v = torch.chunk(kv, 2, dim=-1)
         else:
-            q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
-            k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
-            v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
-
-            attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
-            assert attn_weights.size() == (
-                batch_size * self.nhead,
-                q.size(1),
-                k.size(1),
-            )
-
-            if attn_bias is not None:
-                attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
-                attn_weights = attn_weights + attn_bias
-
-            if attn_mask is not None:
-                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
-
-            attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
-            returned_weights = attn_weights.view(
-                batch_size, self.nhead, q.size(1), k.size(1)
-            )
-
-            attn_probs = self.dropout(attn_weights)
-            attn_output = torch.bmm(attn_probs, v)
-            attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
+            assert kv is None, f"kv should be None for self attention"
+            assert (
+                freqs_cis_kv is None
+            ), f"freqs_cis_kv should be None for self attention"
+            q, k, v = torch.chunk(self.qkv_proj(q), 3, dim=-1)
+
+        # max_batch_size, max_seq_length, n_heads, head_dim
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
+        q = apply_rotary_emb(q, freqs_cis_q)
+
+        if freqs_cis_kv is None:
+            freqs_cis_kv = freqs_cis_q
+
+        # Only do when self attention or cross attention without kv cache
+        if k is not None:
+            assert v is not None, "v should not be None when k is not None"
+            k = rearrange(k, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
+            v = rearrange(v, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
+            k = apply_rotary_emb(k, freqs_cis_kv)
+
+        if kv_cache is not None:
+            if k is None:
+                assert v is None, "v should be None when k is None"
+                k, v = kv_cache[0], kv_cache[1]
+            else:
+                k = torch.cat([kv_cache[0], k], dim=1)
+                v = torch.cat([kv_cache[1], v], dim=1)
+                kv_cache = (k, v)
+
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+        value = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            dropout_p=self.dropout.p if self.training else 0,
+        )
 
-        attn_output = self.out_proj(attn_output)
-        return attn_output, returned_weights
+        value = rearrange(value, "b h t d -> b t (h d)")
+        return self.o_proj(value), kv_cache
 
 
 class GluMLP(nn.Module):
@@ -246,76 +185,80 @@ class RMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
-class CrossAttentionLayer(nn.Module):
-    def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
         super().__init__()
 
-        self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
-        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
-        self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
-        self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
-        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+        self.attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
+        self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+
+        self.attention_norm = RMSNorm(hidden_size, eps=1e-6)
+        self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
 
     def forward(
         self,
-        tgt,
-        memory,
-        memory_key_padding_mask=None,
+        x,
+        freqs_cis,
+        attn_mask=None,
         input_pos=None,
     ):
-        residual = tgt
-        tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
-        x, attn_weights = self.attn(
-            tgt,
-            memory,
-            memory,
-            key_padding_mask=memory_key_padding_mask,
-            return_weights=True,
-            input_pos=input_pos,
+        x = (
+            x
+            + self.attention(
+                q=self.attention_norm(x),
+                freqs_cis_q=freqs_cis,
+                attn_mask=attn_mask,
+                input_pos=input_pos,
+            )[0]
         )
-        residual = x + residual
 
-        x = self.post_attention_layernorm(residual)
-        x = self.mlp(x)
-        x = x + residual
+        return x + self.ffn(self.ffn_norm(x))
 
-        return x, attn_weights
 
-
-class TransformerEncoderLayer(nn.Module):
+class TransformerDecoderLayer(nn.Module):
     def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
         super().__init__()
 
-        self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
-        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
-        self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
-        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
+        self.self_attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
+        self.cross_attention = MultiheadAttention(
+            hidden_size, nhead, dropout=dropout, is_cross_attention=True
+        )
+        self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
+
+        self.self_attention_norm = RMSNorm(hidden_size, eps=1e-6)
+        self.cross_attention_norm = RMSNorm(hidden_size, eps=1e-6)
+        self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
 
     def forward(
-        self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None, input_pos=None
+        self,
+        x,
+        context,
+        freqs_cis_q,
+        freqs_cis_kv,
+        self_attn_mask=None,
+        cross_attn_mask=None,
+        input_pos=None,
     ):
-        residual = x
-        x = self.input_layernorm(x)
-        x, _ = self.attn(
-            x,
-            x,
-            x,
-            attn_bias=attn_bias,
-            key_padding_mask=key_padding_mask,
-            attn_mask=tgt_mask,
-            return_weights=False,
+        x = x + self.self_attention(
+            q=self.self_attention_norm(x),
+            freqs_cis_q=freqs_cis_q,
+            attn_mask=self_attn_mask,
             input_pos=input_pos,
         )
-        residual = x + residual
 
-        x = self.post_attention_layernorm(residual)
-        x = self.mlp(x)
-        x = x + residual
+        x = x + self.cross_attention(
+            q=self.cross_attention_norm(x),
+            kv=context,
+            freqs_cis_q=freqs_cis_q,
+            freqs_cis_kv=freqs_cis_kv,
+            attn_mask=cross_attn_mask,
+            input_pos=input_pos,
+        )
 
-        return x
+        return x + self.ffn(self.ffn_norm(x))
 
 
-class FishSpeechTransformer(nn.Module):
+class Transformer(nn.Module):
     def __init__(
         self,
         vocab_size,
@@ -327,8 +270,7 @@ class FishSpeechTransformer(nn.Module):
         num_encoder_layers=12,
         num_decoder_layers=12,
         dropout=0.1,
-        alignment_position=-2,
-        max_position=8192,
+        max_position=4096,
     ):
         super().__init__()
 
@@ -339,6 +281,7 @@ class FishSpeechTransformer(nn.Module):
         self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
         self.codebook_size = codebook_size
         self.num_codebooks = num_codebooks
+        self.nhead = nhead
 
         self.encoder = nn.ModuleList(
             [
@@ -352,21 +295,9 @@ class FishSpeechTransformer(nn.Module):
             ]
         )
 
-        self.alignment = CrossAttentionLayer(
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            dropout=dropout,
-        )
-
-        if alignment_position < 0:
-            alignment_position = num_decoder_layers + alignment_position
-
-        self.alignment_position = alignment_position
-        assert 0 <= alignment_position < num_decoder_layers
-
         self.decoder = nn.ModuleList(
             [
-                TransformerEncoderLayer(
+                TransformerDecoderLayer(
                     hidden_size=hidden_size,
                     intermediate_size=intermediate_size,
                     nhead=nhead,
@@ -376,12 +307,21 @@ class FishSpeechTransformer(nn.Module):
             ]
         )
 
-        self.alibi = AlibiPostionEmbedding(nhead, max_position)
         self.register_buffer(
-            "causual_mask",
-            torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
+            "freqs_cis",
+            precompute_freqs_cis(hidden_size // nhead, max_position, theta=10000.0),
         )
 
+        causual_mask = torch.triu(
+            torch.ones(max_position, max_position), diagonal=1
+        ).bool()
+        causual_mask = torch.zeros(max_position, max_position).masked_fill(
+            causual_mask, float("-inf")
+        )
+
+        self.register_buffer("causual_mask", causual_mask)
+
+        # The following are reserved for kv cache
         self.max_batch_size = -1
         self.max_seq_length = -1
 
@@ -399,284 +339,156 @@ class FishSpeechTransformer(nn.Module):
         self.max_batch_size = max_batch_size
 
         for b in self.decoder:
-            b.attn.kv_cache = KVCache(
-                max_batch_size, max_seq_length, b.attn.nhead, b.attn.head_dim
-            )
-
-    def forward(self, inputs, codes, input_mask=None, codes_mask=None):
-        # x: (B, T)
-        # y: (B, C, T)
-        inputs = self.encoder_embedding(inputs)
-        codes = rearrange(codes, "b c t -> c b t")
-        codes = torch.stack(
-            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
+            b.self_attention.kv_cache = KVCache(
+                max_batch_size,
+                max_seq_length,
+                b.self_attention.nhead,
+                b.self_attention.head_dim,
+            ).to(b.self_attention_norm.weight.device)
+
+            b.cross_attention.kv_cache = KVCache(
+                max_batch_size,
+                max_seq_length,
+                b.cross_attention.nhead,
+                b.cross_attention.head_dim,
+            ).to(b.cross_attention_norm.weight.device)
+
+    def get_key_padding_mask(self, key_padding_mask, q_size=None):
+        # key_padding_mask: (B, T) bool -> (B, nhead, 1 or q_size, T) additive float mask (-inf at padded positions)
+        assert key_padding_mask.dtype == torch.bool and key_padding_mask.ndim == 2
+
+        key_padding_mask = (
+            key_padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.nhead, -1, -1)
         )
-        codes = torch.mean(codes, dim=0)  # (B, T)
-
-        attn_bias = self.alibi(inputs)
-        for layer in self.encoder:
-            inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)
 
-        attn_bias = self.alibi(codes)
-        causual_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
-
-        for idx, layer in enumerate(self.decoder):
-            if idx == self.alignment_position:
-                codes, _ = self.alignment(
-                    codes, inputs, memory_key_padding_mask=input_mask
-                )
-
-            codes = layer(
-                codes,
-                attn_bias=attn_bias,
-                key_padding_mask=codes_mask,
-                tgt_mask=causual_mask,
-            )
-
-        codes = self.decoder_head(codes)
-        codes = rearrange(
-            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
+        key_padding_mask = key_padding_mask.reshape(
+            key_padding_mask.shape[0], self.nhead, 1, key_padding_mask.shape[1]
         )
 
-        return codes
-
-    def sample_decoder(
-        self,
-        x: torch.Tensor,
-        context: torch.Tensor,
-        input_pos: torch.Tensor,
-        **sampling_kwargs,
-    ):
-        attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
-        causual_mask = self.causual_mask[input_pos, : self.max_seq_length]
+        if q_size is not None:
+            key_padding_mask = key_padding_mask.expand(-1, -1, q_size, -1)
 
-        x = rearrange(x, "b c t -> c b t")
-        x = torch.stack(
-            [emb(code) for emb, code in zip(self.decoder_embeddings, x)], dim=0
+        new_mask = torch.zeros(
+            *key_padding_mask.shape, dtype=torch.float, device=key_padding_mask.device
         )
-        x = torch.mean(x, dim=0)  # (B, T)
+        new_mask = new_mask.masked_fill(key_padding_mask, float("-inf"))
 
-        for idx, layer in enumerate(self.decoder):
-            if idx == self.alignment_position:
-                x, _ = self.alignment(x, context)
-
-            x = layer(
-                x, attn_bias=attn_bias, input_pos=input_pos, tgt_mask=causual_mask
-            )
+        return new_mask
 
-        x = self.decoder_head(x)
-        x = rearrange(
-            x, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
-        )
+    def forward_encoder(
+        self, inputs, input_mask=None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # inputs: (B, T)
+        # input_mask: (B, T), bool mask
+        inputs = self.encoder_embedding(inputs)
 
-        # Never predict EOS or BOS for sub-codebooks
-        x[:, 1:, :2] = -float("Inf")
-
-        next_token, probs = [], []
-        for i in range(self.num_codebooks):
-            next_token_i, probs_i = self.sample(x[:, i], **sampling_kwargs)
-            next_token.append(next_token_i)
-            probs.append(probs_i)
-
-        return torch.stack(next_token, dim=0), torch.stack(probs, dim=0)
-
-    @staticmethod
-    def multinomial_sample_one_no_sync(
-        probs_sort,
-    ):  # Does multinomial sampling without a cuda synchronization
-        q = torch.empty_like(probs_sort).exponential_(1)
-        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-    @staticmethod
-    def logits_to_probs(
-        logits,
-        temperature: float = 1.0,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-    ):
-        if top_p is not None:
-            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-            cum_probs = torch.cumsum(
-                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+        # Calculate mask
+        if input_mask is None:
+            # Assume no padding
+            input_mask = torch.zeros(
+                inputs.shape[0], inputs.shape[1], dtype=torch.bool, device=inputs.device
             )
-            sorted_indices_to_remove = cum_probs > top_p
-            sorted_indices_to_remove[0] = False  # keep at least one option
-            indices_to_remove = sorted_indices_to_remove.scatter(
-                dim=0, index=sorted_indices, src=sorted_indices_to_remove
-            )
-            logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
-        logits = logits / max(temperature, 1e-5)
+        input_mask = self.get_key_padding_mask(input_mask, q_size=None).to(inputs.dtype)
 
-        if top_k is not None:
-            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-            pivot = v.select(-1, -1).unsqueeze(-1)
-            logits = torch.where(logits < pivot, -float("Inf"), logits)
+        freqs_cis = self.freqs_cis[: inputs.shape[1]]
+        input_mask_self = input_mask.expand(-1, -1, inputs.shape[1], -1)
 
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        return probs
+        for layer in self.encoder:
+            inputs = layer(inputs, freqs_cis=freqs_cis, attn_mask=input_mask_self)
 
-    def sample(
-        self,
-        logits,
-        temperature: float = 1.0,
-        top_p: Optional[int] = None,
-        top_k: Optional[int] = None,
-    ):
-        probs = self.logits_to_probs(logits[0, -1], temperature, top_p, top_k)
-        idx_next = self.multinomial_sample_one_no_sync(probs)
-        return idx_next, probs
+        return inputs, input_mask
 
-    def decode_n_tokens(
-        self,
-        cur_token: torch.Tensor,
-        context: torch.Tensor,
-        input_pos: torch.Tensor,
-        num_new_tokens: int,
-        callback=lambda _: _,
-        **sampling_kwargs,
+    def forward_decoder(
+        self, codes, inputs, input_mask, codes_mask=None, input_pos=None
     ):
-        new_tokens, new_probs = [], []
-        # Sliding context window
-        batch_size = 1
-        back_map = torch.zeros(
-            [batch_size, 1], device=cur_token.device, dtype=torch.long
-        )
-
-        for i in range(num_new_tokens):
-            next_token, next_prob = self.sample_decoder(
-                cur_token, context, input_pos, **sampling_kwargs
-            )
-
-            # index_map = torch.arange(6, device=cur_token.device)
-            # index_map = back_map[:, -1:] + index_map.repeat(batch_size, 1)
-            # add = torch.arange(batch_size, device=index_map.device).unsqueeze(1) #N, 1
-            # index_map = index_map + add * t_length
-
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-
-            if next_token[0, 0] == 1:
-                break
-
-            cur_token = next_token.view(1, self.num_codebooks, -1)
+        # codes: (B, C, T)
+        # inputs: (B, T, N)
 
-        return new_tokens, new_probs
-
-    def compile(self):
-        self.sampler_decoder = torch.compile(
-            self.sample_decoder, mode="reduce-overhead", fullgraph=True
+        print(f"Codes: {codes.shape}, Inputs: {inputs.shape}")
+        codes = rearrange(codes, "b c t -> c b t")
+        codes = torch.stack(
+            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
         )
+        codes = torch.mean(codes, dim=0)  # (B, T, D): mean over codebook embeddings
 
-    @torch.no_grad()
-    def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
-        # inputs: (B, T)
-        # prompt: (B, C, T)
-
-        assert inputs.size(0) == 1, "Only support batch size 1 for now"
+        # If kv cache is enabled
+        input_mask = input_mask.expand(-1, -1, codes.shape[1], -1)
 
-        if prompt is None:
-            prompt = torch.tensor(
-                [[[0]] * self.num_codebooks], device=inputs.device, dtype=torch.long
-            )
+        # Calculate mask
+        if input_pos is not None:
+            attn_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
+        else:
+            attn_mask = None
 
-        T = prompt.size(2)
-        T_new = T + max_new_tokens
+        # if codes_mask is not None:
+        #     codes_mask = self.get_key_padding_mask(codes_mask)
+        #     attn_mask = attn_mask + codes_mask
 
-        # Encode Features
-        inputs = self.encoder_embedding(inputs)
-        attn_bias = self.alibi(inputs)
-        for layer in self.encoder:
-            inputs = layer(inputs, attn_bias=attn_bias)
+        # For kv cache
+        if input_pos is not None:
+            freqs_cis_q = self.freqs_cis[input_pos]
+        else:
+            freqs_cis_q = self.freqs_cis[: codes.shape[1]]
 
-        device, dtype = inputs.device, inputs.dtype
+        freqs_cis_kv = self.freqs_cis[: inputs.shape[1]]
 
-        # Decode
-        with torch.device(inputs.device):
-            self.setup_kv_caches(max_batch_size=1, max_seq_length=T_new)
+        for layer in self.decoder:
+            codes = layer(
+                codes,
+                inputs,
+                freqs_cis_q=freqs_cis_q,
+                freqs_cis_kv=freqs_cis_kv,
+                self_attn_mask=attn_mask,
+                cross_attn_mask=input_mask,
+                input_pos=input_pos,
+            )
 
-        # create an empty tensor of the expected final shape and fill in the current tokens
-        empty = torch.empty(
-            (1, self.num_codebooks, T_new), dtype=torch.long, device=device
+        codes = self.decoder_head(codes)
+        codes = rearrange(
+            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
         )
-        empty[:, :, :T] = prompt
-        seq = empty
-        input_pos = torch.arange(0, T, device=device)
 
-        # prefill
-        next_token, _ = self.sample_decoder(
-            prompt.view(1, self.num_codebooks, -1), inputs, input_pos, **sampling_kwargs
-        )
-        seq[:, :, T] = next_token
+        return codes
 
-        # create an empty tensor of the expected final shape and fill in the current tokens
-        input_pos = torch.tensor([T], device=device, dtype=torch.long)
-        generated_tokens, _ = self.decode_n_tokens(
-            next_token.view(1, self.num_codebooks, -1),
-            context=inputs,
-            input_pos=input_pos,
-            num_new_tokens=max_new_tokens - 1,
-            **sampling_kwargs,
-        )
+    def forward(
+        self,
+        inputs,
+        codes,
+        input_mask=None,
+        codes_mask=None,
+        input_pos=None,
+    ):
+        # inputs: (B, T)
+        # codes: (B, C, T)
+        # input_mask: (B, T), bool mask
+        # codes_mask: (B, T), bool mask
+        # input_pos: (B, T), int mask
 
-        generated_tokens = torch.stack(generated_tokens, dim=-1)
-        seq = seq[:, :, : T + 1 + generated_tokens.size(-1)]
-        seq[:, :, T + 1 :] = generated_tokens
+        inputs, input_mask = self.forward_encoder(inputs, input_mask)
+        codes = self.forward_decoder(codes, inputs, input_mask, codes_mask, input_pos)
 
-        return seq
+        return codes
 
 
 if __name__ == "__main__":
-    # mha = MultiheadAttention(512, 8, dropout=0)
-    # mha.eval()
-    # mha.cuda()
-
-    # q, k, v = torch.randn(3, 10, 16, 512)
-    # q, k, v = q.cuda(), k.cuda(), v.cuda()
-    # alibi = AlibiPostionEmbedding(8, 1024)
-
-    # mha.bfloat16()
-    # q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
-    # bias = alibi(q).bfloat16()
-
-    # # Causual mask
-    # attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
-    # o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
-
-    # print(o.size())
-    # print(w.size())
+    mha = MultiheadAttention(512, 8, dropout=0, is_cross_attention=True)
+    mha.eval()
+    mha.cuda()
 
-    # o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
-    # print(o1.size())
+    q, kv = torch.randn(2, 10, 16, 512)
+    q, kv = q.cuda(), kv.cuda()
 
-    # print(o[0], o1.float()[0])
+    mha.bfloat16()
+    q, kv = q.bfloat16(), kv.bfloat16()
+    freqs_cis = precompute_freqs_cis(512 // 8, 4096 * 2).cuda()[:16]
 
-    # assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
-    # print("ok")
-
-    # cross = CrossAttentionLayer(512, 1024, dropout=0)
-    # cross.eval()
-    # cross.cuda()
-
-    # tgt = torch.randn(3, 10, 512).cuda()
-    # memory = torch.randn(3, 20, 512).cuda()
-    # o, w = cross(tgt, memory)
-
-    # print(o.size())
-    # print(w.size())
-
-    # ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
-    # ten.eval()
-    # ten.cuda()
-
-    # tgt = torch.randn(3, 10, 512).cuda()
-    # o = ten(tgt)
-    # print(o.size())
+    # Causal mask
+    attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
+    o = mha(q, freqs_cis, kv=kv, attn_mask=attn_mask)
 
     trans = (
-        FishSpeechTransformer(
+        Transformer(
             vocab_size=30000,
             codebook_size=120,
             num_codebooks=4,
@@ -689,11 +501,34 @@ if __name__ == "__main__":
         .bfloat16()
         .cuda()
     )
+    trans.eval()
+
     # Print n param
     print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
-    inputs = torch.randint(0, 1000, (1, 16)).cuda()
-    codes = torch.randint(0, 120, (1, 4, 128)).cuda()
-    print(trans(inputs, codes).size())
+    inputs = torch.randint(0, 1000, (2, 16)).cuda()
+    codes = torch.randint(0, 120, (2, 4, 128)).cuda()
+    x = trans(inputs, codes)
+    x1 = trans(inputs, codes)
+
+    assert torch.allclose(x, x1, atol=1e-4, rtol=1e-3), "Model is not deterministic"
+    print("Model is deterministic")
+
+    # Test kv cache
+    trans.setup_kv_caches(2, 1024)
+    inputs, inputs_mask = trans.forward_encoder(inputs)
+
+    outputs = []
+
+    for i in range(128):
+        code = codes[..., i].unsqueeze(-1)
+        code_mask = torch.tensor([[1], [1]], dtype=torch.bool, device=code.device)
+        input_pos = torch.tensor([i], dtype=torch.long, device=code.device)
+        outputs.append(
+            trans.forward_decoder(
+                code, inputs, inputs_mask, code_mask, input_pos=input_pos
+            )
+        )
 
-    r = trans.inference(inputs, max_new_tokens=1024, top_k=5, temperature=0.3)
-    print(r)
+    outputs = torch.cat(outputs, dim=2)
+    print(x.shape, outputs.shape)
+    assert torch.allclose(x, outputs, atol=1e-4, rtol=1e-3), "KV cache is not working"

+ 699 - 0
fish_speech/models/text2semantic/modules_old.py

@@ -0,0 +1,699 @@
+import math
+from typing import Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError as e:
+    memory_efficient_attention = None
+
+
+class AlibiPostionEmbedding(nn.Module):
+    def __init__(self, nheads, maxpos):
+        super().__init__()
+
+        context_position = torch.arange(maxpos)[:, None]
+        memory_position = torch.arange(maxpos)[None, :]
+        relative_position = memory_position - context_position
+        relative_position = (
+            torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
+        )
+        self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
+        alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
+        alibi = alibi.view(nheads, maxpos, maxpos)
+
+        self.register_buffer("alibi", alibi)
+
+    @staticmethod
+    def get_slopes_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    def get_slopes(self, n):
+        if math.log2(n).is_integer():
+            return self.get_slopes_power_of_2(n)
+
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            self.get_slopes_power_of_2(closest_power_of_2)
+            + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+    def __call__(self, x):
+        # N, T, C
+        return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
+
+
+class KVCache(nn.Module):
+    def __init__(
+        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, max_seq_length, n_heads * head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        assert input_pos is not None, "input_pos should not be None"
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, input_pos] = k_val
+        v_out[:, input_pos] = v_val
+
+        return k_out, v_out
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model, nhead, dropout=0.1):
+        super().__init__()
+        assert d_model % nhead == 0
+        self.nhead = nhead
+        self.d_model = d_model
+        self.head_dim = d_model // nhead
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.kv_cache = None
+
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        key_padding_mask=None,
+        attn_bias=None,
+        return_weights=False,
+        input_pos=None,
+    ):
+        # (B, T, C)
+        batch_size = q.size(0)
+        q_length = q.size(1)
+
+        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k_length = k.size(1)
+
+        if attn_bias is not None:
+            assert attn_bias.size() == (
+                self.nhead,
+                q_length,
+                k_length,
+            ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
+
+            attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
+
+        if attn_mask is not None:
+            assert attn_mask.size() == (
+                q_length,
+                k_length,
+            ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
+            assert attn_mask.dtype == torch.bool
+            attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size() == (
+                batch_size,
+                k_length,
+            ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
+            assert key_padding_mask.dtype == torch.bool
+            key_padding_mask = (
+                key_padding_mask.unsqueeze(1)
+                .unsqueeze(1)
+                .expand(-1, self.nhead, -1, -1)
+            )
+            key_padding_mask = key_padding_mask.reshape(
+                batch_size * self.nhead, 1, k_length
+            )
+            if attn_mask is None:
+                attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
+            else:
+                attn_mask = attn_mask.logical_or(key_padding_mask)
+
+        if (
+            return_weights is False
+            and memory_efficient_attention is not None
+            and q.device.type == "cuda"
+        ):
+            # (-> b, t, n, d)
+            q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
+
+            if attn_mask is not None:
+                attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
+
+                if attn_bias is None:
+                    attn_bias = torch.zeros_like(
+                        attn_mask, dtype=q.dtype, device=q.device
+                    )
+                attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
+
+            if attn_bias is not None:
+                attn_bias = attn_bias.to(q.dtype)
+
+            attn_output = memory_efficient_attention(
+                q,
+                k,
+                v,
+                attn_bias=attn_bias,
+                scale=self.head_dim**-0.5,
+                p=self.dropout.p,
+            )
+            attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
+
+            returned_weights = None
+        else:
+            q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
+            k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
+            v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
+
+            attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
+            assert attn_weights.size() == (
+                batch_size * self.nhead,
+                q.size(1),
+                k.size(1),
+            )
+
+            if attn_bias is not None:
+                attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
+                attn_weights = attn_weights + attn_bias
+
+            if attn_mask is not None:
+                attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
+
+            attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
+            returned_weights = attn_weights.view(
+                batch_size, self.nhead, q.size(1), k.size(1)
+            )
+
+            attn_probs = self.dropout(attn_weights)
+            attn_output = torch.bmm(attn_probs, v)
+            attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output, returned_weights
+
+
+class GluMLP(nn.Module):
+    def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
+        super().__init__()
+
+        if intermediate_size is None:
+            intermediate_size = hidden_size * (11 / 3)
+            intermediate_size = round(intermediate_size / 8) * 8
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = activation()
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return self.weight * hidden_states.to(input_dtype)
+
+
class CrossAttentionLayer(nn.Module):
    """Single-head cross-attention block (pre-norm) followed by a gated MLP.

    Used to align the decoder stream (``tgt``) with encoder memory; returns
    the updated stream together with the attention weights.
    """

    def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
        super().__init__()
        # A single attention head is used for the alignment attention.
        self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
        self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
        self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        memory_key_padding_mask=None,
        input_pos=None,
    ):
        """Attend from ``tgt`` (queries) over ``memory`` (keys/values)."""
        shortcut = tgt
        q = self.input_layernorm_q(tgt)
        kv = self.input_layernorm_kv(memory)
        attended, attn_weights = self.attn(
            q,
            kv,
            kv,
            key_padding_mask=memory_key_padding_mask,
            return_weights=True,
            input_pos=input_pos,
        )
        hidden = attended + shortcut

        # Pre-norm MLP with its own residual connection.
        out = self.mlp(self.post_attention_layernorm(hidden)) + hidden
        return out, attn_weights
+
+
class TransformerEncoderLayer(nn.Module):
    """Pre-norm self-attention block with a gated MLP (LLaMA-style layer)."""

    def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
        super().__init__()
        self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
        self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
        self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)

    def forward(
        self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None, input_pos=None
    ):
        """Self-attend over ``x``; bias/masks are forwarded to the attention op."""
        shortcut = x
        normed = self.input_layernorm(x)
        attended, _ = self.attn(
            normed,
            normed,
            normed,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            attn_mask=tgt_mask,
            return_weights=False,
            input_pos=input_pos,
        )
        hidden = attended + shortcut

        # Pre-norm MLP with its own residual connection.
        return self.mlp(self.post_attention_layernorm(hidden)) + hidden
+
+
class FishSpeechTransformer(nn.Module):
    """Encoder-decoder transformer that maps text tokens to multi-codebook
    semantic codes.

    Text is embedded and encoded by self-attention layers; codes are embedded
    per codebook, averaged into one stream, and decoded with causally-masked
    self-attention layers. Encoder context is injected exactly once via a
    single cross-attention ("alignment") layer at ``alignment_position``
    inside the decoder stack. ALiBi biases provide positional information;
    a registered causal mask (buffer name keeps the original "causual"
    spelling) restricts decoder self-attention.
    """

    def __init__(
        self,
        vocab_size,
        codebook_size,
        num_codebooks,
        hidden_size=1024,
        intermediate_size=None,
        nhead=16,
        num_encoder_layers=12,
        num_decoder_layers=12,
        dropout=0.1,
        alignment_position=-2,
        max_position=8192,
    ):
        # vocab_size: text vocabulary size; codebook_size: entries per
        # acoustic codebook; num_codebooks: codebooks predicted in parallel.
        # alignment_position: decoder layer index before which the cross-
        # attention runs; negative values count from the end (Python-style).
        super().__init__()

        self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
        # One embedding table per codebook; outputs are averaged in forward().
        self.decoder_embeddings = nn.ModuleList(
            [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
        )
        # Joint head over all codebooks; split into (codebook, vocab) later.
        self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    nhead=nhead,
                    dropout=dropout,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        # Single-head cross-attention aligning decoder with encoder output.
        self.alignment = CrossAttentionLayer(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout=dropout,
        )

        # Negative positions count from the end of the decoder stack.
        if alignment_position < 0:
            alignment_position = num_decoder_layers + alignment_position

        self.alignment_position = alignment_position
        assert 0 <= alignment_position < num_decoder_layers

        self.decoder = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    nhead=nhead,
                    dropout=dropout,
                )
                for _ in range(num_decoder_layers)
            ]
        )

        self.alibi = AlibiPostionEmbedding(nhead, max_position)
        # Upper-triangular True entries block attention to future positions.
        self.register_buffer(
            "causual_mask",
            torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
        )

        # KV-cache geometry; -1 means "caches not allocated yet".
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_kv_caches(self, max_batch_size, max_seq_length):
        """Allocate (or grow) per-decoder-layer KV caches for incremental decoding."""
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            # Existing caches already cover the requested geometry.
            return

        # Round the sequence length up to a multiple of 8.
        if max_seq_length % 8 != 0:
            max_seq_length = max_seq_length + (8 - max_seq_length % 8)

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

        for b in self.decoder:
            b.attn.kv_cache = KVCache(
                max_batch_size, max_seq_length, b.attn.nhead, b.attn.head_dim
            )

    def forward(self, inputs, codes, input_mask=None, codes_mask=None):
        """Teacher-forced training pass.

        Args:
            inputs: text token ids, shape (B, T_text).
            codes: codebook ids, shape (B, C, T_codes).
            input_mask: key-padding mask for the text stream.
            codes_mask: key-padding mask for the code stream.

        Returns:
            Logits of shape (B, C, T_codes, codebook_size).
        """
        # x: (B, T)
        # y: (B, C, T)
        inputs = self.encoder_embedding(inputs)
        codes = rearrange(codes, "b c t -> c b t")
        codes = torch.stack(
            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
        )
        # Average per-codebook embeddings into a single decoder stream.
        codes = torch.mean(codes, dim=0)  # (B, T)

        attn_bias = self.alibi(inputs)
        for layer in self.encoder:
            inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)

        attn_bias = self.alibi(codes)
        causual_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]

        for idx, layer in enumerate(self.decoder):
            # Inject encoder context once, just before this decoder layer.
            if idx == self.alignment_position:
                codes, _ = self.alignment(
                    codes, inputs, memory_key_padding_mask=input_mask
                )

            codes = layer(
                codes,
                attn_bias=attn_bias,
                key_padding_mask=codes_mask,
                tgt_mask=causual_mask,
            )

        codes = self.decoder_head(codes)
        codes = rearrange(
            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
        )

        return codes

    def sample_decoder(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        input_pos: torch.Tensor,
        **sampling_kwargs,
    ):
        """One incremental decoder step.

        Embeds ``x``, runs the decoder stack with KV caches at ``input_pos``,
        and samples the next token for every codebook.
        """
        attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
        causual_mask = self.causual_mask[input_pos, : self.max_seq_length]

        x = rearrange(x, "b c t -> c b t")
        x = torch.stack(
            [emb(code) for emb, code in zip(self.decoder_embeddings, x)], dim=0
        )
        x = torch.mean(x, dim=0)  # (B, T)

        for idx, layer in enumerate(self.decoder):
            if idx == self.alignment_position:
                x, _ = self.alignment(x, context)

            x = layer(
                x, attn_bias=attn_bias, input_pos=input_pos, tgt_mask=causual_mask
            )

        x = self.decoder_head(x)
        x = rearrange(
            x, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
        )

        # Never predict EOS or BOS for sub-codebooks
        x[:, 1:, :2] = -float("Inf")

        next_token, probs = [], []
        for i in range(self.num_codebooks):
            next_token_i, probs_i = self.sample(x[:, i], **sampling_kwargs)
            next_token.append(next_token_i)
            probs.append(probs_i)

        return torch.stack(next_token, dim=0), torch.stack(probs, dim=0)

    @staticmethod
    def multinomial_sample_one_no_sync(
        probs_sort,
    ):  # Does multinomial sampling without a cuda synchronization
        # Exponential-race trick: argmax(p / Exp(1)) draws from the
        # categorical distribution without torch.multinomial's device sync.
        q = torch.empty_like(probs_sort).exponential_(1)
        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

    @staticmethod
    def logits_to_probs(
        logits,
        temperature: float = 1.0,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
    ):
        """Convert a 1-D logits vector to probabilities.

        Applies, in order: nucleus (top-p) filtering on the raw logits,
        temperature scaling, then top-k filtering.
        """
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(
                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
            )
            sorted_indices_to_remove = cum_probs > top_p
            sorted_indices_to_remove[0] = False  # keep at least one option
            # Map the "remove" flags back from sorted order to vocab order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=0, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, -float("Inf"))

        # Floor the temperature to avoid division by zero.
        logits = logits / max(temperature, 1e-5)

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            pivot = v.select(-1, -1).unsqueeze(-1)
            logits = torch.where(logits < pivot, -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        return probs

    def sample(
        self,
        logits,
        temperature: float = 1.0,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
    ):
        """Sample one token id from logits[0, -1] (batch 0, last position)."""
        probs = self.logits_to_probs(logits[0, -1], temperature, top_p, top_k)
        idx_next = self.multinomial_sample_one_no_sync(probs)
        return idx_next, probs

    def decode_n_tokens(
        self,
        cur_token: torch.Tensor,
        context: torch.Tensor,
        input_pos: torch.Tensor,
        num_new_tokens: int,
        callback=lambda _: _,
        **sampling_kwargs,
    ):
        """Autoregressively sample up to ``num_new_tokens`` steps.

        Stops early when codebook 0 emits token id 1 (treated as EOS).
        ``callback`` is invoked with each newly sampled token.
        """
        new_tokens, new_probs = [], []
        # Sliding context window
        batch_size = 1
        # NOTE(review): back_map is computed but never used below — the
        # commented-out sliding-window remap appears unfinished.
        back_map = torch.zeros(
            [batch_size, 1], device=cur_token.device, dtype=torch.long
        )

        for i in range(num_new_tokens):
            # NOTE(review): this always calls the eager sample_decoder; see
            # compile() below, which stores the compiled fn under a
            # different attribute name.
            next_token, next_prob = self.sample_decoder(
                cur_token, context, input_pos, **sampling_kwargs
            )

            # index_map = torch.arange(6, device=cur_token.device)
            # index_map = back_map[:, -1:] + index_map.repeat(batch_size, 1)
            # add = torch.arange(batch_size, device=index_map.device).unsqueeze(1) #N, 1
            # index_map = index_map + add * t_length

            input_pos += 1
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob.clone())

            # Token id 1 in the first codebook marks end-of-sequence.
            if next_token[0, 0] == 1:
                break

            cur_token = next_token.view(1, self.num_codebooks, -1)

        return new_tokens, new_probs

    def compile(self):
        # NOTE(review): the compiled function is stored as `sampler_decoder`
        # (likely a typo) while decode_n_tokens/inference call
        # `sample_decoder`, so the compiled version is never invoked —
        # confirm and align the names.
        self.sampler_decoder = torch.compile(
            self.sample_decoder, mode="reduce-overhead", fullgraph=True
        )

    @torch.no_grad()
    def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
        """Generate codes for a single text sequence.

        Args:
            inputs: text token ids, shape (1, T_text).
            prompt: optional code prompt (1, C, T); defaults to a single
                BOS step of zeros for every codebook.
            max_new_tokens: generation budget.

        Returns:
            Code tensor of shape (1, C, T_prompt + 1 + n_generated).
        """
        # inputs: (B, T)
        # prompt: (B, C, T)

        assert inputs.size(0) == 1, "Only support batch size 1 for now"

        if prompt is None:
            prompt = torch.tensor(
                [[[0]] * self.num_codebooks], device=inputs.device, dtype=torch.long
            )

        T = prompt.size(2)
        T_new = T + max_new_tokens

        # Encode Features
        inputs = self.encoder_embedding(inputs)
        attn_bias = self.alibi(inputs)
        for layer in self.encoder:
            inputs = layer(inputs, attn_bias=attn_bias)

        device, dtype = inputs.device, inputs.dtype

        # Decode
        with torch.device(inputs.device):
            self.setup_kv_caches(max_batch_size=1, max_seq_length=T_new)

        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(
            (1, self.num_codebooks, T_new), dtype=torch.long, device=device
        )
        empty[:, :, :T] = prompt
        seq = empty
        input_pos = torch.arange(0, T, device=device)

        # prefill
        # NOTE(review): sample_decoder stacks tokens over codebooks into
        # shape (C, 1); assigning that into seq[:, :, T] (shape (1, C))
        # looks shape-mismatched for C > 1 — verify against a real run.
        next_token, _ = self.sample_decoder(
            prompt.view(1, self.num_codebooks, -1), inputs, input_pos, **sampling_kwargs
        )
        seq[:, :, T] = next_token

        # create an empty tensor of the expected final shape and fill in the current tokens
        input_pos = torch.tensor([T], device=device, dtype=torch.long)
        generated_tokens, _ = self.decode_n_tokens(
            next_token.view(1, self.num_codebooks, -1),
            context=inputs,
            input_pos=input_pos,
            num_new_tokens=max_new_tokens - 1,
            **sampling_kwargs,
        )

        generated_tokens = torch.stack(generated_tokens, dim=-1)
        # Trim the preallocated buffer to what was actually generated.
        seq = seq[:, :, : T + 1 + generated_tokens.size(-1)]
        seq[:, :, T + 1 :] = generated_tokens

        return seq
+
+
if __name__ == "__main__":
    # Smoke test: builds a small model, runs one teacher-forced forward pass,
    # then exercises the autoregressive inference path. Requires CUDA.
    # The commented-out snippets below are earlier ad-hoc tests of the
    # attention / layer primitives, kept for reference.

    # mha = MultiheadAttention(512, 8, dropout=0)
    # mha.eval()
    # mha.cuda()

    # q, k, v = torch.randn(3, 10, 16, 512)
    # q, k, v = q.cuda(), k.cuda(), v.cuda()
    # alibi = AlibiPostionEmbedding(8, 1024)

    # mha.bfloat16()
    # q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
    # bias = alibi(q).bfloat16()

    # # Causual mask
    # attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
    # o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)

    # print(o.size())
    # print(w.size())

    # o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
    # print(o1.size())

    # print(o[0], o1.float()[0])

    # assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
    # print("ok")

    # cross = CrossAttentionLayer(512, 1024, dropout=0)
    # cross.eval()
    # cross.cuda()

    # tgt = torch.randn(3, 10, 512).cuda()
    # memory = torch.randn(3, 20, 512).cuda()
    # o, w = cross(tgt, memory)

    # print(o.size())
    # print(w.size())

    # ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
    # ten.eval()
    # ten.cuda()

    # tgt = torch.randn(3, 10, 512).cuda()
    # o = ten(tgt)
    # print(o.size())

    trans = (
        FishSpeechTransformer(
            vocab_size=30000,
            codebook_size=120,
            num_codebooks=4,
            hidden_size=1024,
            intermediate_size=None,
            nhead=16,
            num_encoder_layers=12,
            num_decoder_layers=12,
        )
        .bfloat16()
        .cuda()
    )
    # Print n param
    # NOTE(review): this divides an element count by 1024*1024, so the label
    # is really "millions of params (binary)"; confirm intended units.
    print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
    inputs = torch.randint(0, 1000, (1, 16)).cuda()
    codes = torch.randint(0, 120, (1, 4, 128)).cuda()
    print(trans(inputs, codes).size())

    r = trans.inference(inputs, max_new_tokens=1024, top_k=5, temperature=0.3)
    print(r)

+ 500 - 0
fish_speech/models/text2semantic/quantize.py

@@ -0,0 +1,500 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import time
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fish_speech.models.text2semantic.llama import Transformer, find_multiple
+
+##### Quantization Primitives ######
+
+
+def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+    # assumes symmetric quantization
+    # assumes axis == 0
+    # assumes dense memory format
+    # TODO(future): relax ^ as needed
+
+    # default setup for affine quantization of activations
+    eps = torch.finfo(torch.float32).eps
+
+    # get min and max
+    min_val, max_val = torch.aminmax(x, dim=1)
+
+    # calculate scales and zero_points based on min and max
+    # reference: https://fburl.com/code/srbiybme
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    device = min_val_neg.device
+
+    # reference: https://fburl.com/code/4wll53rk
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scales = max_val_pos / (float(quant_max - quant_min) / 2)
+    # ensure scales is the same dtype as the original tensor
+    scales = torch.clamp(scales, min=eps).to(x.dtype)
+    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+    # quantize based on qmin/qmax/scales/zp
+    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
+    x_div = x / scales.unsqueeze(-1)
+    x_round = torch.round(x_div)
+    x_zp = x_round + zero_points.unsqueeze(-1)
+    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+    return quant, scales, zero_points
+
+
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group (scale, zero) pairs for asymmetric n-bit quantization.

    ``w`` is a 2-D weight; each row is split into groups of ``groupsize``
    columns. Returns bfloat16 tensors shaped (out_features, n_groups).
    """
    # GPTQ may quantize a padded single column; shrink the group if needed.
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    max_val = grouped.amax(dim=1, keepdim=True)
    min_val = grouped.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    # Zero stored so that dequant is (q - 2^(n-1)) * scale + zero.
    zeros = min_val + scales * (2 ** (n_bit - 1))

    out_shape = (w.shape[0], -1)
    return (
        scales.to(torch.bfloat16).reshape(out_shape),
        zeros.to(torch.bfloat16).reshape(out_shape),
    )
+
+
def pack_scales_and_zeros(scales, zeros):
    """Interleave scales and zeros into one (n_groups, out_features, 2) tensor."""
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    n, g = scales.size(0), scales.size(1)
    stacked = torch.cat([scales.reshape(n, g, 1), zeros.reshape(n, g, 1)], 2)
    # Kernel expects a groups-major, contiguous layout.
    return stacked.transpose(0, 1).contiguous()
+
+
def unpack_scales_and_zeros(scales_and_zeros):
    """Inverse of :func:`pack_scales_and_zeros`: split back into (scales, zeros).

    NOTE(review): asserts float32 input although ``pack_scales_and_zeros``
    emits bfloat16 — callers apparently convert in between; verify.
    """
    assert scales_and_zeros.ndim == 3 and scales_and_zeros.shape[2] == 2
    swapped = scales_and_zeros.transpose(0, 1)
    assert scales_and_zeros.dtype == torch.float
    return torch.split(swapped, 1, 2)
+
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize ``w`` to unsigned n-bit integer codes given per-group qparams."""
    assert groupsize > 1
    # GPTQ single-column path: collapse to one group per row.
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    # Recover each group's minimum from the stored zero point.
    min_val = zeros - scales * (2 ** (n_bit - 1))
    max_int = 2**n_bit - 1

    codes = (grouped - min_val) / scales
    return codes.round().clamp_(0, max_int).to(torch.int32).reshape_as(w)
+
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Full groupwise quantization: returns (int32 codes, packed qparams)."""
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    codes = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    return codes, pack_scales_and_zeros(scales, zeros)
+
+
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Map n-bit integer codes back to floats using per-group qparams."""
    assert groupsize > 1
    # GPTQ single-column path: collapse to one group per row.
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)

    # Codes are stored with a +2^(n-1) offset around the zero point.
    centered = grouped.sub(2 ** (n_bit - 1))
    return centered.mul(scales).add(zeros).reshape_as(w_int32)
+
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize packed groupwise codes back to a float tensor."""
    scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
    return group_dequantize_tensor_from_qparams(
        w_int32, scales, zeros, n_bit, groupsize
    )
+
+
class QuantHandler:
    """Abstract interface for weight-quantization strategies.

    Subclasses produce a quantized state dict and swap modules for their
    quantized equivalents in-place on ``self.mod``.
    """

    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> "StateDict":
        """Return a state dict with quantized weights (subclass hook)."""
        pass

    def convert_for_runtime(self) -> "nn.Module":
        """Replace modules with quantized versions and return the model."""
        pass
+
+
+##### Weight-only int8 per-channel quantized code ######
+
+
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` for a ``WeightOnlyInt8Linear``.

    NOTE(review): any existing bias is dropped by the replacement — fine for
    the bias-free llama layers this targets; verify for other models.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(child)
            continue
        quantized = WeightOnlyInt8Linear(child.in_features, child.out_features)
        setattr(module, name, quantized)
+
+
class WeightOnlyInt8QuantHandler:
    """Weight-only int8 symmetric per-channel quantization strategy."""

    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Quantize every Linear weight to int8, adding a ``scales`` entry."""
        state_dict = self.mod.state_dict()
        for fqn, submodule in self.mod.named_modules():
            if not isinstance(submodule, torch.nn.Linear):
                continue
            int8_weight, scales, _ = dynamically_quantize_per_channel(
                submodule.weight.float(), -128, 127, torch.int8
            )
            state_dict[f"{fqn}.weight"] = int8_weight
            # Store scales in the original weight dtype for later rescaling.
            state_dict[f"{fqn}.scales"] = scales.to(submodule.weight.dtype)

        return state_dict

    def convert_for_runtime(self):
        """Swap Linears for WeightOnlyInt8Linear modules; return the model."""
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
+
+
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer storing an int8 weight plus per-output-channel scales.

    The matmul runs in the input dtype (the weight is up-cast on the fly)
    and the result is rescaled per channel. The ``bias`` argument is
    accepted for signature compatibility but no bias is applied.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Quantized weight codes and their per-channel dequant scales.
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        unscaled = F.linear(input, self.weight.to(dtype=input.dtype))
        return unscaled * self.scales
+
+
+##### weight only int4 per channel groupwise quantized code ######
+
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Groupwise-quantize a bf16 weight and pack it for the int4 matmul kernel."""
    weight_int32, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    # Repack into the tiled layout consumed by aten._weight_int4pack_mm.
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
        weight_int32, inner_k_tiles
    )
    return weight_int4pack, scales_and_zeros
+
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply the packed-int4 matmul kernel, flattening leading batch dims."""
    lead_shape = x.size()[:-1]
    flat = x.reshape(-1, x.size(-1))
    out = torch.ops.aten._weight_int4pack_mm(
        flat, weight_int4pack, groupsize, scales_and_zeros
    )
    # Restore the original leading dimensions with the new feature width.
    return out.reshape(lead_shape + (out_features,))
+
+
+def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
+    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
+
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    """Recursively swap ``nn.Linear`` modules for ``WeightOnlyInt4Linear``.

    Layers whose in_features don't fit the kernel tiling are padded when
    ``padding`` is set, and left untouched otherwise.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
            continue

        fits = _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
        if fits or padding:
            setattr(
                module,
                name,
                WeightOnlyInt4Linear(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    # Pad only the layers that don't already fit the tiling.
                    padding=not fits,
                ),
            )
+
+
class WeightOnlyInt4QuantHandler:
    """Weight-only int4 groupwise quantization strategy.

    Quantizes each eligible ``nn.Linear`` weight into the packed layout used
    by ``aten._weight_int4pack_mm``; optionally pads in_features up to a
    multiple of 1024 so otherwise-incompatible layers still qualify.
    """

    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return a state dict with packed int4 weights and qparams.

        Quantization/packing runs on CUDA; results are moved back to CPU
        for serialization.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue

            # Bug fix: the previous `assert not mod.bias` raised an
            # ambiguous-truth-value RuntimeError whenever a multi-element
            # bias was present; test for absence explicitly instead.
            assert mod.bias is None, f"{fqn}: int4 path requires bias=False"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"
            print(f"linear: {fqn}, in={in_features}, out={out_features}")

            weight = mod.weight.data
            if not _check_linear_int4_k(
                in_features, self.groupsize, self.inner_k_tiles
            ):
                if not self.padding:
                    print(
                        f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                        + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                    )
                    continue

                # Zero-pad the input dimension so the kernel tiling fits.
                print(
                    f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                )
                padded_in_features = find_multiple(in_features, 1024)
                weight = F.pad(weight, pad=(0, padded_in_features - in_features))

            (
                weight_int4pack,
                scales_and_zeros,
            ) = prepare_int4_weight_and_scales_and_zeros(
                weight.to(torch.bfloat16).to("cuda"),
                self.groupsize,
                self.inner_k_tiles,
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap Linears for WeightOnlyInt4Linear modules; return the model."""
        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod
+
+
class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer whose weight lives in the packed int4 tile layout.

    ``bias`` must be False. With ``padding`` enabled, in_features is rounded
    up to a multiple of 1024 at construction and inputs are zero-padded on
    the fly in forward().
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        padding: bool = True,
    ) -> None:
        super().__init__()
        self.padding = padding
        if padding:
            # Remember the caller-visible width so forward() can pad inputs.
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"

        # Tiled layout consumed by aten._weight_int4pack_mm.
        packed_shape = (
            out_features // 8,
            in_features // (inner_k_tiles * 16),
            32,
            inner_k_tiles // 2,
        )
        self.register_buffer("weight", torch.empty(packed_shape, dtype=torch.int32))
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = input.to(torch.bfloat16)
        if self.padding:
            x = F.pad(x, pad=(0, self.in_features - self.origin_in_features))
        return linear_forward_int4(
            x, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
+
+
def quantize(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
    mode: str = "int8",
    # following arguments only available when setting int4 quantization.
    groupsize: int = 128,
    label: str = "",
) -> None:
    """Quantize a llama checkpoint and write the result next to the original.

    Args:
        checkpoint_path: path to a ``model.pth`` checkpoint file.
        mode: "int8" (symmetric per-channel) or "int4" (affine groupwise).
        groupsize: group size; only meaningful for int4 mode.
        label: tag inserted into the output filename.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    # Validate the mode up front, before any expensive model loading.
    # Bug fix: the old message advertised an unsupported "int4-gpptq" mode.
    if mode not in ("int8", "int4"):
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4]")

    assert checkpoint_path.is_file(), checkpoint_path

    device = "cpu"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()

    # Build the architecture lazily on the meta device, then attach the
    # real weights with assign=True (avoids a duplicate allocation).
    with torch.device("meta"):
        model = Transformer.from_name(checkpoint_path.parent.name)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device=device)

    if mode == "int8":
        print(
            "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
        )
        quant_handler = WeightOnlyInt8QuantHandler(model)
        suffix = f"{label}int8.pth"
    else:  # mode == "int4"
        print(
            "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
        )
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        suffix = f"{label}int4.g{groupsize}.pth"

    quantized_state_dict = quant_handler.create_quantized_state_dict()

    new_base_name = checkpoint_path.name.replace(".pth", suffix)
    quantize_path = checkpoint_path.parent / new_base_name
    print(f"Writing quantized weights to {quantize_path}")
    quantize_path.unlink(missing_ok=True)  # remove existing file if one already there
    torch.save(quantized_state_dict, quantize_path)
    print(f"Quantization complete took {time.time() - t0:.02f} seconds")
+
+
if __name__ == "__main__":
    import argparse

    # Thin CLI wrapper around quantize().
    parser = argparse.ArgumentParser(description="Quantize a model.")
    parser.add_argument(
        "--checkpoint_path",
        type=Path,
        default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
        help="Path to the model checkpoint to be quantized.",
    )
    parser.add_argument(
        "--mode",
        "-q",
        type=str,
        default="int8",
        choices=["int8", "int4"],
        help="type of quantization to perform",
    )
    parser.add_argument(
        # NOTE(review): CLI default (32) differs from quantize()'s own
        # default (128) — confirm which is intended.
        "--groupsize", type=int, default=32, help="Group size for int4 quantization."
    )
    parser.add_argument(
        # NOTE(review): default "_" differs from quantize()'s default "" —
        # output filenames produced via the CLI always carry an underscore.
        "--label", type=str, default="_", help="label to add to output filename"
    )

    args = parser.parse_args()
    quantize(args.checkpoint_path, args.mode, args.groupsize, args.label)

+ 170 - 0
fish_speech/models/text2semantic/tp.py

@@ -0,0 +1,170 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+from quantize import WeightOnlyInt4Linear
+from torch import nn
+from torch.distributed import _functional_collectives as funcol
+
+from fish_speech.models.text2semantic.llama import Attention, FeedForward, Transformer
+
+
+# Helpers for reading torchrun's per-process environment; they default to a
+# single-process view (rank 0, world size 1) when the variables are absent.
+def _get_rank() -> int:
+    # LOCAL_RANK is set per worker process by torchrun.
+    return int(os.environ.get("LOCAL_RANK", "0"))
+
+
+def is_local():
+    # True only on the rank-0 process.
+    return _get_rank() == 0
+
+
+def local_break():
+    # Debug helper: drop rank 0 into the debugger while all other ranks wait
+    # at the barrier, so the whole job stays synchronized around the pause.
+    if is_local():
+        breakpoint()
+    dist.barrier()
+
+
+def _get_world_size() -> int:
+    # LOCAL_WORLD_SIZE is set by torchrun to the number of local worker processes.
+    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
+
+
+def maybe_init_dist() -> Optional[int]:
+    # Initialize the NCCL process group when launched under torchrun with two
+    # or more processes. Returns this process's rank, or None when tensor
+    # parallelism should be a no-op (single process / not under torchrun).
+    try:
+        # provided by torchrun
+        rank = _get_rank()
+        world_size = _get_world_size()
+
+        if world_size < 2:
+            # too few gpus to parallelize, tp is no-op
+            return None
+    except KeyError:
+        # not run via torchrun, no-op
+        # NOTE(review): dead branch — _get_rank/_get_world_size read the env
+        # with os.environ.get(...) and defaults, so KeyError cannot be raised.
+        return None
+
+    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    return rank
+
+
+# Shard one nn.Linear's weight (and quantization metadata, if present) across
+# tensor-parallel ranks, in place. "colwise" splits the output features,
+# "rowwise" splits the input features. When weight_splits is given, the layer
+# is a fused QKV projection and each of the three segments is sharded
+# independently before being re-fused.
+# NOTE(review): `weight_splits: List[int] = []` is a mutable default argument;
+# it is never mutated here so it is harmless, but `Optional[List[int]] = None`
+# would be the safer idiom.
+def _apply_tp_linear(
+    linear: nn.Linear, style: str, weight_splits: List[int] = []
+) -> None:
+    rank = _get_rank()
+    world_size = _get_world_size()
+
+    # Linear's weight matrix is transposed, and is of shape
+    # (linear.out_features, linear.in_features)
+    dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}
+    assert style in dim_lookup
+    shard_dim, size_attr = dim_lookup[style]
+
+    # ensure we can shard evenly
+    assert getattr(linear, size_attr) % world_size == 0
+
+    def shard(x, dim):
+        # Take this rank's contiguous slice along `dim`.
+        assert x.size(dim=dim) % world_size == 0
+        return torch.tensor_split(x, world_size, dim=dim)[rank]
+
+    def shard_qkv(qkv, dim, weight_splits):
+        # Split the fused tensor into Q/K/V segments, shard each segment for
+        # this rank, then concatenate back into a fused layout.
+        q, k, v = qkv.split(weight_splits, dim=dim)
+        q = shard(q, dim)
+        k = shard(k, dim)
+        v = shard(v, dim)
+        return torch.cat((q, k, v), dim=dim)
+
+    # shard
+    if weight_splits:
+        # attention
+        assert len(weight_splits) == 3
+
+        if isinstance(linear, WeightOnlyInt4Linear):
+            # int4 weights are stored packed, so the split sizes along the
+            # packed axis are divided by 8 — assumes 8 values per storage
+            # element; TODO confirm against WeightOnlyInt4Linear's packing.
+            sharded_weight = shard_qkv(
+                linear.weight, shard_dim, [i // 8 for i in weight_splits]
+            )
+            # scales_and_zeros is laid out with the sharded axis transposed
+            # relative to the weight, hence `1 - shard_dim`.
+            linear.scales_and_zeros = shard_qkv(
+                linear.scales_and_zeros, 1 - shard_dim, weight_splits
+            )
+        else:
+            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
+        if hasattr(linear, "scales") and style == "colwise":
+            # int8 per-channel scales follow the output-feature shards.
+            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
+    else:
+        sharded_weight = shard(linear.weight, shard_dim)
+        if isinstance(linear, WeightOnlyInt4Linear):
+            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
+            if style == "rowwise":
+                # Sanity-check the metadata/weight shapes still correspond
+                # after sharding. The literal 32 appears to be the int4
+                # groupsize and 8 the packing factor — TODO confirm.
+                assert (
+                    linear.scales_and_zeros.shape[0] * 32
+                    == sharded_weight.shape[1]
+                    * sharded_weight.shape[2]
+                    * sharded_weight.shape[3]
+                )
+                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
+        if hasattr(linear, "scales") and style == "colwise":
+            linear.scales = shard(linear.scales, 0)
+
+    # local_break()
+    # Install the shard and shrink the logical size attribute to match.
+    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+    setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
+
+    # shape info should still be synced
+    # assert linear.weight.shape == (linear.out_features, linear.in_features)
+
+
+def _apply_tp_ffn(mlp: FeedForward) -> None:
+    # Shard a feed-forward block: the two input projections (w1, w3) are split
+    # column-wise and the output projection (w2) row-wise, so each rank
+    # produces a partial output of full width.
+    assert hasattr(mlp, "w1")
+    assert hasattr(mlp, "w3")
+    assert hasattr(mlp, "w2")
+
+    _apply_tp_linear(mlp.w1, "colwise")
+    _apply_tp_linear(mlp.w3, "colwise")
+    _apply_tp_linear(mlp.w2, "rowwise")
+
+    # Sum the partial outputs across all ranks; the hook's return value
+    # replaces the module's output.
+    world_size = _get_world_size()
+    mlp.register_forward_hook(
+        lambda _module, _input, output: funcol.all_reduce(
+            output, "sum", list(range(world_size))
+        )
+    )
+
+
+def _apply_tp_attn(attn: Attention) -> None:
+    # Shard an attention block: the fused QKV projection is split column-wise
+    # with per-segment sizes [query width, kv width, kv width], and the output
+    # projection row-wise.
+    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wo")
+
+    # n_local_heads appears to be the number of K/V heads (GQA-style); TODO
+    # confirm against the Attention definition in llama.py.
+    kv_size = attn.n_local_heads * attn.head_dim
+    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
+    _apply_tp_linear(attn.wo, "rowwise")
+
+    # overwrite
+    # Shrink the per-rank head bookkeeping. head_dim is recomputed from two
+    # values that were both divided by world_size, so it stays unchanged.
+    world_size = _get_world_size()
+    attn.n_head = attn.n_head // world_size
+    attn.dim = attn.dim // world_size
+    attn.head_dim = attn.dim // attn.n_head
+    attn.n_local_heads = attn.n_local_heads // world_size
+
+    # Sum partial attention outputs across ranks. Only output[0] is reduced —
+    # presumably the module returns a tuple whose first element is the hidden
+    # state; verify against Attention.forward.
+    attn.register_forward_hook(
+        lambda _module, _input, output: funcol.all_reduce(
+            output[0], "sum", list(range(world_size))
+        )
+    )
+
+
+def _apply_tp_Transformer(Transformer: Transformer) -> None:
+    # overwrite config before Transformer.setup_cache is called
+    # NOTE(review): the parameter is named "Transformer", shadowing the class
+    # it is annotated with; renaming it to "model" would be clearer.
+    world_size = _get_world_size()
+    Transformer.config.n_head = Transformer.config.n_head // world_size
+    Transformer.config.dim = Transformer.config.dim // world_size
+    Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
+
+
+def apply_tp(model: Transformer) -> None:
+    # Entry point: apply tensor parallelism to the whole model — first shrink
+    # the config (before caches are set up), then shard every layer's
+    # feed-forward and attention modules in place.
+    _apply_tp_Transformer(model)
+    for block in model.layers:
+        # Apply to MLP
+        _apply_tp_ffn(block.feed_forward)
+        _apply_tp_attn(block.attention)

+ 2 - 4
tools/infer_vq.py

@@ -69,10 +69,8 @@ def main():
     print(indices.shape)
 
     # Restore
-    indices = np.load(
-        "data/LibriTTS_R/train-clean-100/7226/86964/7226_86964_000012_000003.npy"
-    )
-    indices = torch.from_numpy(indices).to(model.device)
+    indices = np.load("codes_0.npy")
+    indices = torch.from_numpy(indices).to(model.device).long()
     indices = indices.unsqueeze(1).unsqueeze(-1)
     mel_lengths = indices.shape[2] * (
         model.downsample.total_strides if model.downsample is not None else 1

+ 12 - 0
tools/llama/extract_model.py

@@ -0,0 +1,12 @@
+import torch
+
+state_dict = torch.load(
+    "results/text2semantic_400m/checkpoints/step_000025000.ckpt", map_location="cpu"
+)["state_dict"]
+state_dict = {
+    state_dict.replace("model.", ""): value
+    for state_dict, value in state_dict.items()
+    if state_dict.startswith("model.")
+}
+
+torch.save(state_dict, "results/text2semantic_400m/step_000025000_weights.ckpt")