Bläddra i källkod

Fix quantization & reduce memory usage

Lengyue 1 år sedan
förälder
incheckning
691b3bb937
3 ändrade filer med 38 tillägg och 60 borttagningar
  1. 4 4
      tools/llama/generate.py
  2. 25 53
      tools/llama/quantize.py
  3. 9 3
      tools/webui.py

+ 4 - 4
tools/llama/generate.py

@@ -367,7 +367,7 @@ def load_model(
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
-        from quantize import WeightOnlyInt8QuantHandler
+        from .quantize import WeightOnlyInt8QuantHandler
 
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
@@ -377,7 +377,7 @@ def load_model(
         path_comps = checkpoint_path.name.split(".")
         assert path_comps[-2].startswith("g")
         groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
+        from .quantize import WeightOnlyInt4QuantHandler
 
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
@@ -669,9 +669,9 @@ def launch_thread_safe_queue(
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="results/text2semantic_400m_finetune/step_000002000.pth",
+    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
 )
-@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
+@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)

+ 25 - 53
tools/llama/quantize.py

@@ -6,11 +6,12 @@
 import time
 from pathlib import Path
 
+import click
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
+from .generate import load_model
 
 ##### Quantization Primitives ######
 
@@ -414,11 +415,21 @@ class WeightOnlyInt4Linear(torch.nn.Module):
         )
 
 
+@click.command()
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
+)
+@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
+@click.option(
+    "--mode", type=str, default="int8", help="type of quantization to perform"
+)
+@click.option(
+    "--groupsize", type=int, default=128, help="Group size for int4 quantization."
+)
 def quantize(
-    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
-    mode: str = "int8",
-    # following arguments only available when setting int4 quantization.
-    groupsize: int = 128,
+    checkpoint_path: Path, config_name: str, mode: str, groupsize: int
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
 
@@ -428,31 +439,14 @@ def quantize(
     print("Loading model ...")
     t0 = time.time()
 
-    with torch.device("meta"):
-        model = Transformer(
-            ModelArgs(
-                max_seq_len=4096,
-                vocab_size=36408,
-                n_layer=24,
-                n_head=16,
-                dim=1024,
-                rope_base=10000,
-                norm_eps=1e-5,
-                num_codebooks=4,  # single codebook
-                codebook_size=168,  # codebook size 160 + 2 special tokens
-            )
-        )
-
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "state_dict" in checkpoint:
-        checkpoint = checkpoint["state_dict"]
-    checkpoint = {
-        k.replace("model.", ""): v
-        for k, v in checkpoint.items()
-        if k.startswith("model.")
-    }
-    model.load_state_dict(checkpoint, assign=True)
-    model = model.to(dtype=precision, device=device)
+    model, _ = load_model(
+        config_name,
+        checkpoint_path=checkpoint_path,
+        device=device,
+        precision=precision,
+        compile=False,
+        max_length=2048,
+    )
 
     if mode == "int8":
         print(
@@ -490,26 +484,4 @@ def quantize(
 
 
 if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Quantize a model.")
-    parser.add_argument(
-        "--checkpoint_path",
-        type=Path,
-        default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
-        help="Path to the model checkpoint to be quantized.",
-    )
-    parser.add_argument(
-        "--mode",
-        "-q",
-        type=str,
-        default="int8",
-        choices=["int8", "int4"],
-        help="type of quantization to perform",
-    )
-    parser.add_argument(
-        "--groupsize", type=int, default=32, help="Group size for int4 quantization."
-    )
-
-    args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize)
+    quantize()

+ 9 - 3
tools/webui.py

@@ -138,9 +138,15 @@ def inference(
 
         # VQGAN Inference
         feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
-        fake_audios = vqgan_model.decode(
-            indices=result[None], feature_lengths=feature_lengths, return_audios=True
-        )[0, 0]
+
+        with torch.autocast(
+            device_type=feature_lengths.device.type, dtype=args.precision
+        ):
+            fake_audios = vqgan_model.decode(
+                indices=result[None],
+                feature_lengths=feature_lengths,
+                return_audios=True,
+            )[0, 0]
         fake_audios = fake_audios.float().cpu().numpy()
 
         if streaming: