Jelajahi Sumber

Polish inference script

Lengyue 2 tahun lalu
induk
melakukan
7a0086536c
4 mengubah file dengan 224 tambahan dan 438 penghapusan
  1. 5 2
      README.md
  2. 17 4
      README.zh.md
  3. 202 262
      tools/llama/generate.py
  4. 0 170
      tools/llama/tp.py

+ 5 - 2
README.md

@@ -10,7 +10,7 @@ This codebase is released under BSD-3-Clause License, and all models are release
 We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
 
 ## Requirements
-- GPU memory: 4GB (for inference), 24GB (for finetuning)
+- GPU memory: 2GB (for inference), 24GB (for finetuning)
 - System: Linux (full functionality), Windows (inference only, flash-attn is not supported, torch.compile is not supported)
 
 Therefore, we strongly recommend to use WSL2 or docker to run the codebase for Windows users.
@@ -38,7 +38,10 @@ TODO
 
 Generate semantic tokens from text:
 ```bash
-python tools/llama/generate.py
+python tools/llama/generate.py \
+    --text "Hello" \
+    --num-samples 2 \
+    --compile
 ```
 
 You may want to use `--compile` to fuse cuda kernels faster inference (~25 tokens/sec -> ~300 tokens/sec).

+ 17 - 4
README.zh.md

@@ -8,7 +8,7 @@
 我们不对代码库的任何非法使用承担任何责任。请参阅您当地关于DMCA和其他相关法律的法律。
 
 ## 要求
-- GPU内存:4GB(用于推理),24GB(用于微调)
+- GPU内存:2GB(用于推理),24GB(用于微调)
 - 系统:Linux(全部功能),Windows(仅推理,不支持flash-attn,不支持torch.compile)
 
 因此,我们强烈建议Windows用户使用WSL2或docker来运行代码库。
@@ -35,14 +35,27 @@ pip3 install -e .
 TODO
 ```
 
-从文本生成语义 token
+### [可选] 从语音生成 prompt
 ```bash
-python tools/llama/generate.py
+python tools/vqgan/inference.py -i codes_0.wav
+```
+
+你应该能得到一个 `fake.npy` 文件。
+
+### 从文本生成语义 token:
+```bash
+python tools/llama/generate.py \
+    --text "要转换的文本" \
+    --prompt-string "你的参考文本" \
+    --prompt-tokens "fake.npy" \
+    --checkpoint-path results/text2semantic_400m_finetune/step_000002000.pth \
+    --num-samples 2 \
+    --compile
 ```
 
 您可能希望使用 `--compile` 来融合 cuda 内核以实现更快的推理(~25 个 token/秒 -> ~300 个 token/秒)。
 
-从语义 token 生成人声:
+### 从语义 token 生成人声:
 ```bash
 python tools/vqgan/inference.py -i codes_0.npy
 ```

+ 202 - 262
tools/llama/generate.py

@@ -1,30 +1,29 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import itertools
-import sys
+import os
 import time
 from pathlib import Path
 from typing import Optional, Tuple
 
+import click
 import numpy as np
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from tqdm import tqdm
 from transformers import AutoTokenizer
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
 
 
-from fish_speech.models.text2semantic.llama import ModelArgs, Transformer
+from fish_speech.models.text2semantic.llama import Transformer
 from fish_speech.text import g2p
 from fish_speech.text.symbols import pad as pad_symbol
 from fish_speech.text.symbols import pu_symbols
-from tools.llama.tp import maybe_init_dist
 
 
 def multinomial_sample_one_no_sync(
@@ -68,6 +67,7 @@ def logits_to_probs(
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
         pivot = v.select(-1, -1).unsqueeze(-1)
         logits = torch.where(logits < pivot, -float("Inf"), logits)
+
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
 
@@ -75,31 +75,58 @@ def logits_to_probs(
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
+    **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     probs = logits_to_probs(
-        logits[0, -1], previous_tokens, temperature, top_k, top_p, repetition_penalty
+        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
 
 
-def decode_token(
+def decode_one_token(
     model: Transformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
+) -> torch.Tensor:
+    assert input_pos.shape[-1] == 1
+
+    logits = model.forward_generate(x, input_pos)
+    codebooks = [
+        sample(
+            logits.token_logits,
+            previous_tokens=previous_tokens[0],
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    # Disable <s> and </s> tokens for codebooks
+    if model.config.num_codebooks != 0:
+        logits.codebook_logits[:, :, :, :2] = -float("Inf")
+
+        for i in range(model.config.num_codebooks):
+            codebooks.append(
+                sample(
+                    logits.codebook_logits[:, :, i],
+                    previous_tokens=previous_tokens[i + 1],
+                    **sampling_kwargs,
+                )[0]
+            )
+
+    return torch.stack(codebooks, dim=0)
+
+
+def prefill(
+    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
     # input_pos: [B, S]
     logits = model.forward_generate(x, input_pos)
     codebooks = [
         sample(
             logits.token_logits,
-            previous_tokens=previous_tokens[0] if previous_tokens is not None else None,
+            previous_tokens=None,
             **sampling_kwargs,
         )[0]
     ]
@@ -112,9 +139,7 @@ def decode_token(
             codebooks.append(
                 sample(
                     logits.codebook_logits[:, :, i],
-                    previous_tokens=previous_tokens[i]
-                    if previous_tokens is not None
-                    else None,
+                    previous_tokens=None,
                     **sampling_kwargs,
                 )[0]
             )
@@ -127,46 +152,47 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    callback=lambda _: _,
+    eos_token_id: int = 2,
     **sampling_kwargs,
 ):
-    new_tokens = []
-    for i in range(num_new_tokens):
+    previous_tokens = torch.zeros(
+        (model.config.num_codebooks + 1, num_new_tokens),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+
+    for i in tqdm(range(num_new_tokens)):
         with torch.backends.cuda.sdp_kernel(
             enable_flash=False, enable_mem_efficient=False, enable_math=True
         ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_token(
+            next_token = decode_one_token(
                 model,
                 cur_token,
                 input_pos,
-                torch.concat(new_tokens, dim=1) if len(new_tokens) > 0 else None,
+                previous_tokens,
                 **sampling_kwargs,
             )
+
         input_pos += 1
-        new_tokens.append(next_token.clone())
-        callback(new_tokens[-1])
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, i : i + 1] = next_token.view(
+            model.config.num_codebooks + 1, -1
+        )
 
         # TODO: use tokenizer's eos
-        if (cur_token[0, 0, -1] == 2).any():
-            print("EOS detected, stopping generation")
+        if (cur_token[0, 0, -1] == eos_token_id).any():
             break
 
-    return new_tokens
-
-
-def model_forward(model, x, input_pos):
-    return model(x, input_pos)
+    return previous_tokens[:, : i + 1]
 
 
 @torch.no_grad()
 def generate(
+    *,
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    *,
-    interactive: bool,
-    callback=lambda x: x,
+    eos_token_id: int = 2,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -175,20 +201,20 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    if T + max_new_tokens > model.config.max_seq_len:
-        max_new_tokens = model.config.max_seq_len - T
-        print(f"Truncating max_new_tokens to {max_new_tokens}")
 
-    T_new = T + max_new_tokens
-    if interactive:
-        max_seq_length = 350
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
     else:
-        max_seq_length = min(T_new, model.config.max_seq_len)
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
 
     device, dtype = prompt.device, prompt.dtype
-    max_seq_length = max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_len=max_seq_length)
+        model.setup_caches(max_batch_size=1, max_seq_len=T_new)
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
@@ -197,44 +223,43 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)
 
-    next_token = decode_token(
+    next_token = prefill(
         model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
     )
     seq[:, T : T + 1] = next_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    generated_tokens = decode_n_tokens(
+    x = decode_n_tokens(
         model,
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        callback=callback,
+        eos_token_id=eos_token_id,
         **sampling_kwargs,
     )
-    x = torch.cat(generated_tokens, dim=1)
+    # x = torch.cat(generated_tokens, dim=1)
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
 
     return seq
 
 
-def encode_tokens(tokenizer, string, bos=True, device="cuda"):
-    # data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_04.npy
+def encode_tokens(
+    tokenizer, string, bos=True, device="cuda", prompt_string=None, prompt_tokens=None
+):
+    if prompt_string is not None:
+        string = prompt_string + " " + string
 
-    prompt = g2p("<zh>算啦,虽然他罪无可恕,但也有可怜的地方嘛。</zh> {string}")
+    prompt = g2p(string)
     prompt = [
         (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
         for _, i in prompt
     ]
     prompt = " ".join(prompt)
     string = f"[INST] {prompt} [/INST]"
-    print("Encoding string:", string)
-
-    data = np.load("data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_03.npy")
-    codes = [f"<s:{i}>" for i in data[0]]
 
     tokens = tokenizer.encode(
-        string + " ".join(codes),
+        string,
         max_length=10**6,
         add_special_tokens=bos,
         truncation=False,
@@ -242,47 +267,42 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    # zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
-    # prompt = torch.cat((tokens, zeros), dim=0)
+    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
+    prompt = torch.cat((tokens, zeros), dim=0)
+
+    if prompt_tokens is None:
+        return prompt
 
-    # # Get prompt tokens
-    # data = np.load("data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_02.npy")
-    # data = torch.from_numpy(data).to(device=device, dtype=torch.int) + 2
+    # Get prompt tokens
+    assert prompt_tokens.ndim == 2
+    data = prompt_tokens + 2
 
-    # zeros = torch.zeros((1, data.size(1)), dtype=torch.int, device=device) + 32311 # 32311 is the <pad> token
-    # data = torch.cat((zeros, data), dim=0)
-    # prompt = torch.cat((prompt, data), dim=1)
-    # print(prompt)
+    zeros = (
+        torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
+        + tokenizer.pad_token_id
+    )  # 32311 is the <pad> token
+    data = torch.cat((zeros, data), dim=0)
+    prompt = torch.cat((prompt, data), dim=1)
 
-    return tokens
+    return prompt
 
 
-def _load_model(checkpoint_path, device, precision, use_tp):
+def load_model(config_name, checkpoint_path, device, precision):
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+        cfg = compose(config_name=config_name)
+
     with torch.device("meta"):
-        # TODO: support different model archs
-        model = Transformer(
-            ModelArgs(
-                max_seq_len=4096,
-                vocab_size=36408,
-                n_layer=24,
-                n_head=16,
-                dim=1024,
-                rope_base=10000,
-                norm_eps=1e-5,
-                codebook_size=168,
-                num_codebooks=0,
-            )
-        )
+        model: Transformer = instantiate(cfg.model.model)
 
     if "int8" in str(checkpoint_path):
-        print("Using int8 weight-only quantization!")
+        logger.info("Using int8 weight-only quantization!")
         from quantize import WeightOnlyInt8QuantHandler
 
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
 
     if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
+        logger.info("Using int4 quantization!")
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-2].startswith("g")
         groupsize = int(path_comps[-2][1:])
@@ -291,215 +311,135 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
 
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    model.load_state_dict(checkpoint, assign=True)
+    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
 
-    if use_tp:
-        from tp import apply_tp
+    if any(k.startswith("model.") for k in checkpoint):
+        checkpoint = {
+            k.replace("model.", ""): v
+            for k, v in checkpoint.items()
+            if k.startswith("model.")
+        }
 
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
+    model.load_state_dict(checkpoint, assign=True)
 
     model = model.to(device=device, dtype=precision)
-    return model.eval()
-
+    logger.info("Restored model from checkpoint")
 
-B_INST, E_INST = "[INST]", "[/INST]"
+    return model.eval()
 
 
+@click.command()
+@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
+@click.option("--prompt-string", type=str, default=None)
+@click.option(
+    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max_new_tokens", type=int, default=0)
+@click.option("--top_k", type=int, default=50)
+@click.option("--top_p", type=float, default=0.95)
+@click.option("--repetition-penalty", type=float, default=1.05)
+@click.option("--temperature", type=float, default=0.8)
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="results/text2semantic_400m_finetune/step_000002000.pth",
+)
+@click.option("--config-name", type=str, default="text2semantic_finetune")
+@click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
 def main(
-    prompt: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-    interactive: bool = False,
-    num_samples: int = 5,
-    max_new_tokens: int = 100,
-    top_k: int = None,
-    top_p: int = 1.0,
-    repetition_penalty: float = 1.0,
-    temperature: float = 0.8,
-    checkpoint_path: Path = Path(
-        "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
-    ),
-    compile: bool = True,
-    profile: Optional[Path] = None,
-    tokenizer: str = "fishaudio/speech-lm-v1",
+    text: str,
+    prompt_string: Optional[str],
+    prompt_tokens: Optional[Path],
+    num_samples: int,
+    max_new_tokens: int,
+    top_k: int,
+    top_p: int,
+    repetition_penalty: float,
+    temperature: float,
+    checkpoint_path: Path,
+    config_name: str,
+    tokenizer: str,
+    compile: bool,
+    seed: int,
 ) -> None:
-    """Generates text samples based on a pre-trained Transformer model and tokenizer."""
-    assert checkpoint_path.is_file(), checkpoint_path
-
-    global print
-    rank = maybe_init_dist()
-    use_tp = rank is not None
-    if use_tp:
-        torch.cuda.set_device(rank)
-        if rank != 0:
-            # only print on rank 0
-            print = lambda *args, **kwargs: None
-
     device = "cuda"
     precision = torch.bfloat16
 
-    print("Loading model ...")
+    logger.info("Loading model ...")
     t0 = time.time()
-    model = _load_model(checkpoint_path, device, precision, use_tp)
+    model = load_model(config_name, checkpoint_path, device, precision)
+    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
     torch.cuda.synchronize()
-    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-    print(prompt)
-    encoded = encode_tokens(tokenizer, f"{prompt}", bos=True, device=device)
+    prompt_tokens = (
+        torch.from_numpy(np.load(prompt_tokens)).to(device)
+        if prompt_tokens is not None
+        else None
+    )
+    encoded = encode_tokens(
+        tokenizer,
+        text,
+        prompt_string=prompt_string,
+        prompt_tokens=prompt_tokens,
+        bos=True,
+        device=device,
+    )
     prompt_length = encoded.size(1)
+    logger.info(f"Encoded prompt shape: {encoded.shape}")
 
-    torch.manual_seed(1234)
-    model_size = sum(
-        [
-            p.numel() * p.dtype.itemsize
-            for p in itertools.chain(model.parameters(), model.buffers())
-        ]
-    )
+    torch.manual_seed(seed)
     if compile:
-        global decode_token
-        decode_token = torch.compile(
-            decode_token, mode="reduce-overhead", fullgraph=True
+        global decode_one_token
+        decode_one_token = torch.compile(
+            decode_one_token, mode="reduce-overhead", fullgraph=True
         )
 
-    aggregate_metrics = {
-        "tokens_per_sec": [],
-    }
-    start = -1 if compile else 0
-
-    for i in range(start, num_samples):
+    for i in range(num_samples):
         torch.cuda.synchronize()
-        if i >= 0 and interactive:
-            prompt = input("What is your prompt? ")
-            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-
-        if interactive and i >= 0:
-            buffer = []
-            period_id = tokenizer.encode(".")[0]
-            done_generating = False
-
-            def callback(x):
-                nonlocal done_generating
-                if done_generating:
-                    return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if len(buffer) == 4 or done_generating:
-                    print("".join(buffer), end="", flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
-
-        else:
-            callback = lambda x: x
+
         t0 = time.perf_counter()
-        import contextlib
-
-        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
-            prof = contextlib.nullcontext()
-        else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
-        with prof:
-            y = generate(
-                model,
-                encoded,
-                max_new_tokens,
-                interactive=interactive,
-                callback=callback,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-            )
-        if i == -1:
-            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
-            continue
-        if hasattr(prof, "export_chrome_trace"):
-            if use_tp:
-                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
-            else:
-                prof.export_chrome_trace(f"{profile}.json")
+        y = generate(
+            model=model,
+            prompt=encoded,
+            max_new_tokens=max_new_tokens,
+            eos_token_id=tokenizer.eos_token_id,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+        )
+
+        if i == 0 and compile:
+            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
         torch.cuda.synchronize()
         t = time.perf_counter() - t0
 
-        if not interactive:
-            print(tokenizer.decode(y[0, :prompt_length:].tolist()))
-            print(f"Generated {y.size(1) - prompt_length} tokens")
-            # Find all <s:2769>
-            codes = y[0, prompt_length:-1]
-            codes = codes - 32311
-            # print(codes)
-            assert (codes >= 0).all()
-            import numpy as np
-
-            np.save(f"codes_{i}.npy", codes[None].cpu().numpy())
-        else:
-            print()
         tokens_generated = y.size(1) - prompt_length
         tokens_sec = tokens_generated / t
-        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        print(
-            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
+        logger.info(
+            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+        )
+        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        logger.info(
+            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
         )
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-    print("==========")
 
-    print(
-        f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
-    )
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        codes = y[1:, prompt_length:-1]
+        codes = codes - 2
+        assert (codes >= 0).all(), "Codes should be >= 0"
 
+        np.save(f"codes_{i}.npy", codes.cpu().numpy())
+        logger.info(f"Saved codes to codes_{i}.npy")
 
-if __name__ == "__main__":
-    import argparse
 
-    parser = argparse.ArgumentParser(description="Your CLI description.")
-
-    parser.add_argument(
-        "--prompt",
-        type=str,
-        default="感情分析関数では、大規模言語モデルに古典的な散文を分析させます。 分析の観点は比較的単純ですが、論理的な誤りはなく、依然として自己一貫性があることがわかります。",
-        help="Input prompt.",
-    )
-    parser.add_argument(
-        "--interactive",
-        action="store_true",
-        help="Whether to launch in interactive mode",
-    )
-    parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
-    parser.add_argument(
-        "--max_new_tokens", type=int, default=4096, help="Maximum number of new tokens."
-    )
-    parser.add_argument("--top_k", type=int, default=50, help="Top-k for sampling.")
-    parser.add_argument("--top_p", type=int, default=0.95, help="Top-k for sampling.")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0)
-    parser.add_argument(
-        "--temperature", type=float, default=0.8, help="Temperature for sampling."
-    )
-    parser.add_argument(
-        "--checkpoint_path",
-        type=Path,
-        default=Path("results/text2semantic_400m/step_000095000_weights.ckpt"),
-        help="Model checkpoint path.",
-    )
-    parser.add_argument(
-        "--compile", action="store_true", help="Whether to compile the model."
-    )
-    parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
-
-    args = parser.parse_args()
-    main(
-        args.prompt,
-        args.interactive,
-        args.num_samples,
-        args.max_new_tokens,
-        args.top_k,
-        args.top_p,
-        args.repetition_penalty,
-        args.temperature,
-        args.checkpoint_path,
-        args.compile,
-        args.profile,
-    )
+if __name__ == "__main__":
+    main()

+ 0 - 170
tools/llama/tp.py

@@ -1,170 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import os
-from typing import List, Optional
-
-import torch
-import torch.distributed as dist
-from quantize import WeightOnlyInt4Linear
-from torch import nn
-from torch.distributed import _functional_collectives as funcol
-
-from fish_speech.models.text2semantic.llama import Attention, FeedForward, Transformer
-
-
-def _get_rank() -> int:
-    return int(os.environ.get("LOCAL_RANK", "0"))
-
-
-def is_local():
-    return _get_rank() == 0
-
-
-def local_break():
-    if is_local():
-        breakpoint()
-    dist.barrier()
-
-
-def _get_world_size() -> int:
-    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
-
-
-def maybe_init_dist() -> Optional[int]:
-    try:
-        # provided by torchrun
-        rank = _get_rank()
-        world_size = _get_world_size()
-
-        if world_size < 2:
-            # too few gpus to parallelize, tp is no-op
-            return None
-    except KeyError:
-        # not run via torchrun, no-op
-        return None
-
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
-    return rank
-
-
-def _apply_tp_linear(
-    linear: nn.Linear, style: str, weight_splits: List[int] = []
-) -> None:
-    rank = _get_rank()
-    world_size = _get_world_size()
-
-    # Linear's weight matrix is transposed, and is of shape
-    # (linear.out_features, linear.in_features)
-    dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}
-    assert style in dim_lookup
-    shard_dim, size_attr = dim_lookup[style]
-
-    # ensure we can shard evenly
-    assert getattr(linear, size_attr) % world_size == 0
-
-    def shard(x, dim):
-        assert x.size(dim=dim) % world_size == 0
-        return torch.tensor_split(x, world_size, dim=dim)[rank]
-
-    def shard_qkv(qkv, dim, weight_splits):
-        q, k, v = qkv.split(weight_splits, dim=dim)
-        q = shard(q, dim)
-        k = shard(k, dim)
-        v = shard(v, dim)
-        return torch.cat((q, k, v), dim=dim)
-
-    # shard
-    if weight_splits:
-        # attention
-        assert len(weight_splits) == 3
-
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(
-                linear.weight, shard_dim, [i // 8 for i in weight_splits]
-            )
-            linear.scales_and_zeros = shard_qkv(
-                linear.scales_and_zeros, 1 - shard_dim, weight_splits
-            )
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
-    else:
-        sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
-            linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
-            if style == "rowwise":
-                assert (
-                    linear.scales_and_zeros.shape[0] * 32
-                    == sharded_weight.shape[1]
-                    * sharded_weight.shape[2]
-                    * sharded_weight.shape[3]
-                )
-                assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard(linear.scales, 0)
-
-    # local_break()
-    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
-    setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
-
-    # shape info should still be synced
-    # assert linear.weight.shape == (linear.out_features, linear.in_features)
-
-
-def _apply_tp_ffn(mlp: FeedForward) -> None:
-    assert hasattr(mlp, "w1")
-    assert hasattr(mlp, "w3")
-    assert hasattr(mlp, "w2")
-
-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
-
-    world_size = _get_world_size()
-    mlp.register_forward_hook(
-        lambda _module, _input, output: funcol.all_reduce(
-            output, "sum", list(range(world_size))
-        )
-    )
-
-
-def _apply_tp_attn(attn: Attention) -> None:
-    assert hasattr(attn, "wqkv")
-    assert hasattr(attn, "wo")
-
-    kv_size = attn.n_local_heads * attn.head_dim
-    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
-    _apply_tp_linear(attn.wo, "rowwise")
-
-    # overwrite
-    world_size = _get_world_size()
-    attn.n_head = attn.n_head // world_size
-    attn.dim = attn.dim // world_size
-    attn.head_dim = attn.dim // attn.n_head
-    attn.n_local_heads = attn.n_local_heads // world_size
-
-    attn.register_forward_hook(
-        lambda _module, _input, output: funcol.all_reduce(
-            output[0], "sum", list(range(world_size))
-        )
-    )
-
-
-def _apply_tp_Transformer(Transformer: Transformer) -> None:
-    # overwrite config before Transformer.setup_cache is called
-    world_size = _get_world_size()
-    Transformer.config.n_head = Transformer.config.n_head // world_size
-    Transformer.config.dim = Transformer.config.dim // world_size
-    Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size
-
-
-def apply_tp(model: Transformer) -> None:
-    _apply_tp_Transformer(model)
-    for block in model.layers:
-        # Apply to MLP
-        _apply_tp_ffn(block.feed_forward)
-        _apply_tp_attn(block.attention)