Procházet zdrojové kódy

Quantization Support (#316)

* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some description

* Fix dependecies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Fix CUDA env

* Update api usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adapt finetuning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Quantization Support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama před 1 rokem
rodič
revize
ea53678446

+ 20 - 1
fish_speech/models/text2semantic/llama.py

@@ -7,6 +7,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange
+from loguru import logger
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -320,7 +321,7 @@ class BaseTransformer(nn.Module):
         lora_config: LoraConfig | None = None,
         rope_base: int | None = None,
     ) -> "BaseTransformer":
-        config = BaseModelArgs.from_pretrained(path)
+        config = BaseModelArgs.from_pretrained(str(path))
         if max_length is not None:
             config.max_seq_len = max_length
             log.info(f"Override max_seq_len to {max_length}")
@@ -348,6 +349,24 @@ class BaseTransformer(nn.Module):
         if load_weights is False:
             log.info("Randomly initialized model")
         else:
+
+            if "int8" in str(Path(path)):
+                logger.info("Using int8 weight-only quantization!")
+                from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+                simple_quantizer = WeightOnlyInt8QuantHandler(model)
+                model = simple_quantizer.convert_for_runtime()
+
+            if "int4" in str(Path(path)):
+                logger.info("Using int4 quantization!")
+                path_comps = path.name.split("-")
+                assert path_comps[-2].startswith("g")
+                groupsize = int(path_comps[-2][1:])
+                from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+                model = simple_quantizer.convert_for_runtime()
+
             weights = torch.load(
                 Path(path) / "model.pth", map_location="cpu", mmap=True
             )

+ 71 - 10
fish_speech/webui/manage.py

@@ -555,16 +555,16 @@ def fresh_tb_dir():
 
 
 def list_decoder_models():
-    paths = [str(p) for p in Path("checkpoints").glob("vq*.*")] + [
-        str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")
-    ]
+    paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
     if not paths:
         logger.warning("No decoder model found")
     return paths
 
 
 def list_llama_models():
-    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
+    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+    choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+    choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
     if not choices:
         logger.warning("No LLaMA model found")
     return choices
@@ -593,11 +593,7 @@ def fresh_llama_ckpt(llama_use_lora):
 
 
 def fresh_llama_model():
-    choices = [
-        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
-    ]
-    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
-    return gr.Dropdown(choices=choices)
+    return gr.Dropdown(choices=list_llama_models())
 
 
 def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
@@ -627,6 +623,39 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
     return build_html_ok_message(i18n("Merge successfully"))
 
 
+def llama_quantify(llama_weight, quantify_mode):
+    if llama_weight is None or not Path(llama_weight).exists():
+        return build_html_error_message(
+            i18n(
+                "Path error, please check the model file exists in the corresponding path"
+            )
+        )
+    now = generate_folder_name()
+    quantify_cmd = [
+        PYTHON,
+        "tools/llama/quantize.py",
+        "--checkpoint-path",
+        llama_weight,
+        "--mode",
+        quantify_mode,
+        "--timestamp",
+        now,
+    ]
+    logger.info(quantify_cmd)
+    subprocess.run(quantify_cmd)
+    if quantify_mode == "int8":
+        quantize_path = str(
+            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+        )
+    else:
+        quantize_path = str(
+            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+        )
+    return build_html_ok_message(
+        i18n("Quantify successfully") + f"Path: {quantize_path}"
+    )
+
+
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
 
@@ -907,7 +936,34 @@ with gr.Blocks(
                                     value=i18n("Merge"), variant="primary"
                                 )
 
-                        with gr.Tab(label="Tensorboard", id=5):
+                        with gr.Tab(label=i18n("Model Quantization"), id=5):
+                            with gr.Row(equal_height=False):
+                                llama_weight_to_quantify = gr.Dropdown(
+                                    label=i18n("Base LLAMA Model"),
+                                    info=i18n(
+                                        "Type the path or select from the dropdown"
+                                    ),
+                                    choices=list_llama_models(),
+                                    value="checkpoints/fish-speech-1.2",
+                                    allow_custom_value=True,
+                                    interactive=True,
+                                )
+                                quantify_mode = gr.Dropdown(
+                                    label=i18n("Post-quantification Precision"),
+                                    info=i18n(
+                                        "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+                                    ),
+                                    choices=["int8", "int4"],
+                                    value="int8",
+                                    allow_custom_value=False,
+                                    interactive=True,
+                                )
+                            with gr.Row(equal_height=False):
+                                llama_quantify_btn = gr.Button(
+                                    value=i18n("Quantify"), variant="primary"
+                                )
+
+                        with gr.Tab(label="Tensorboard", id=6):
                             with gr.Row(equal_height=False):
                                 tb_host = gr.Textbox(
                                     label=i18n("Tensorboard Host"), value="127.0.0.1"
@@ -1122,6 +1178,11 @@ with gr.Blocks(
         inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
         outputs=[train_error],
     )
+    llama_quantify_btn.click(
+        fn=llama_quantify,
+        inputs=[llama_weight_to_quantify, quantify_mode],
+        outputs=[train_error],
+    )
     infer_checkbox.change(
         fn=change_infer,
         inputs=[

+ 0 - 17
tools/llama/generate.py

@@ -343,23 +343,6 @@ def load_model(checkpoint_path, device, precision, compile=False):
         checkpoint_path, load_weights=True
     )
 
-    if "int8" in str(checkpoint_path):
-        logger.info("Using int8 weight-only quantization!")
-        from .quantize import WeightOnlyInt8QuantHandler
-
-        simple_quantizer = WeightOnlyInt8QuantHandler(model)
-        model = simple_quantizer.convert_for_runtime()
-
-    if "int4" in str(checkpoint_path):
-        logger.info("Using int4 quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
-        from .quantize import WeightOnlyInt4QuantHandler
-
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")
 

+ 24 - 16
tools/llama/quantize.py

@@ -1,5 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+import datetime
+import shutil
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,7 +13,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .generate import load_model
+from fish_speech.models.text2semantic.llama import find_multiple
+from tools.llama.generate import load_model
 
 ##### Quantization Primitives ######
 
@@ -415,23 +418,26 @@ class WeightOnlyInt4Linear(torch.nn.Module):
         )
 
 
+def generate_folder_name():
+    now = datetime.datetime.now()
+    folder_name = now.strftime("%Y%m%d_%H%M%S")
+    return folder_name
+
+
 @click.command()
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
     default="checkpoints/fish-speech-1.2",
 )
-@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option(
     "--mode", type=str, default="int8", help="type of quantization to perform"
 )
 @click.option(
     "--groupsize", type=int, default=128, help="Group size for int4 quantization."
 )
-def quantize(
-    checkpoint_path: Path, config_name: str, mode: str, groupsize: int
-) -> None:
-    assert checkpoint_path.is_file(), checkpoint_path
+@click.option("--timestamp", type=str, default="None", help="When to do quantization")
+def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
 
     device = "cpu"
     precision = torch.bfloat16
@@ -440,13 +446,13 @@ def quantize(
     t0 = time.time()
 
     model, _ = load_model(
-        config_name,
         checkpoint_path=checkpoint_path,
         device=device,
         precision=precision,
         compile=False,
-        max_length=2048,
     )
+    vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+    now = timestamp if timestamp != "None" else generate_folder_name()
 
     if mode == "int8":
         print(
@@ -455,10 +461,11 @@ def quantize(
         quant_handler = WeightOnlyInt8QuantHandler(model)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
 
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.stem
-        suffix = checkpoint_path.suffix
-        quantize_path = dir_name / f"{base_name}.int8{suffix}"
+        dir_name = checkpoint_path
+        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+        shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+        (dst_name / vq_model).unlink()
+        quantize_path = dst_name / "model.pth"
 
     elif mode == "int4":
         print(
@@ -467,10 +474,11 @@ def quantize(
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
 
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
-        suffix = checkpoint_path.suffix
-        quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
+        dir_name = checkpoint_path
+        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+        shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+        (dst_name / vq_model).unlink()
+        quantize_path = dst_name / "model.pth"
 
     else:
         raise ValueError(