
INT8 quantization, 3-4 worker threads

supeng committed 3 weeks ago
commit 663f123f9d
6 changed files with 72 additions and 7 deletions
  1. deploy_multi_worker.sh (+21 -3)
  2. fish_speech/models/text2semantic/inference.py (+43 -4)
  3. pyproject.toml (+3 -0)
  4. tools/api_server.py (+1 -0)
  5. tools/server/api_utils.py (+1 -0)
  6. tools/server/model_manager.py (+3 -0)

deploy_multi_worker.sh (+21 -3)

@@ -1,6 +1,9 @@
 #!/bin/bash
 # Multi-worker deployment script - starts the API service on a single machine
-# Usage: ./deploy_multi_worker.sh [num_workers] [port]
+# Usage: ./deploy_multi_worker.sh [num_workers] [port] [gpu_id] [quantize]
+# Examples:
+#   ./deploy_multi_worker.sh 2 8080 0       # 2 workers, no quantization
+#   ./deploy_multi_worker.sh 3 8080 0 1     # 3 workers, INT8 quantization
 
 set -e
 
@@ -8,6 +11,7 @@ set -e
 NUM_WORKERS=${1:-2}  # default: 2 workers
 PORT=${2:-8080}      # default port: 8080
 GPU_ID=${3:-0}       # default: GPU 0
+QUANTIZE=${4:-0}     # enable INT8 quantization (0=no, 1=yes)
 
 LLAMA_CHECKPOINT="checkpoints/s2-pro"
 DECODER_CHECKPOINT="checkpoints/s2-pro/codec.pth"
@@ -22,13 +26,27 @@ echo "========================================="
 echo "Workers: ${NUM_WORKERS}"
 echo "Port: ${PORT}"
 echo "GPU: ${GPU_ID}"
+echo "Precision: BFloat16 (default, better stability than FP16)"
+echo "Quantize (INT8): ${QUANTIZE}"
 echo "========================================="
 
+# Build the quantization argument
+QUANTIZE_ARG=""
+if [ "${QUANTIZE}" = "1" ]; then
+    QUANTIZE_ARG="--quantize"
+    echo "INT8 quantization enabled: VRAM per worker ~6GB (was ~12GB with BF16)"
+    echo "Recommended workers with INT8: 3-4 per GPU"
+else
+    echo "BF16 mode: VRAM per worker ~10-12GB"
+    echo "Recommended workers with BF16: 2 per GPU"
+fi
+
 # Start the API service
+# Note: --half is intentionally not passed; the default bfloat16 has better numerical stability
 python tools/api_server.py \
   --listen 0.0.0.0:${PORT} \
   --llama-checkpoint-path ${LLAMA_CHECKPOINT} \
   --decoder-checkpoint-path ${DECODER_CHECKPOINT} \
-  --half \
   --workers 1 \
-  --num-workers ${NUM_WORKERS}
+  --num-workers ${NUM_WORKERS} \
+  ${QUANTIZE_ARG}
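
For reference, a minimal usage sketch of the updated script. The worker counts below are assumptions derived from the ~6GB / ~12GB per-worker figures echoed above and a card in the 24GB class, not measurements:

# BF16 (default): 2 workers on GPU 0, roughly 2 x 12GB
./deploy_multi_worker.sh 2 8080 0

# INT8: 4 workers on GPU 0, roughly 4 x 6GB (assumes a ~24GB card)
./deploy_multi_worker.sh 4 8080 0 1

# Rough sanity check that memory usage matches expectations
nvidia-smi --query-gpu=memory.used --format=csv -i 0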

fish_speech/models/text2semantic/inference.py (+43 -4)

@@ -357,7 +357,7 @@ def generate(
     return seq
 
 
-def init_model(checkpoint_path, device, precision, compile=False):
+def init_model(checkpoint_path, device, precision, compile=False, quantize=False):
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     logger.info(f"precision: {precision.__class__.__name__}")
@@ -365,9 +365,44 @@ def init_model(checkpoint_path, device, precision, compile=False):
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")
 
+    # Apply INT8 quantization if requested
+    if quantize:
+        try:
+            import bitsandbytes as bnb
+            logger.info("Applying INT8 quantization with bitsandbytes...")
+
+            # Replace all Linear layers with 8-bit quantized versions
+            def replace_linear_with_int8(module):
+                for name, child in module.named_children():
+                    if isinstance(child, torch.nn.Linear):
+                        # Create 8-bit linear layer
+                        int8_layer = bnb.nn.Linear8bitLt(
+                            child.in_features,
+                            child.out_features,
+                            bias=child.bias is not None,
+                            has_fp16_weights=False,
+                            threshold=6.0
+                        )
+                        # Copy weights
+                        int8_layer.weight = bnb.nn.Int8Params(
+                            child.weight.data,
+                            requires_grad=False,
+                            has_fp16_weights=False
+                        )
+                        if child.bias is not None:
+                            int8_layer.bias = child.bias
+                        setattr(module, name, int8_layer)
+                    else:
+                        replace_linear_with_int8(child)
+
+            replace_linear_with_int8(model)
+            logger.info("INT8 quantization applied successfully")
+        except ImportError:
+            logger.error("bitsandbytes not installed. Install with: pip install bitsandbytes")
+            raise
+
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
-        # prefill_n_tokens = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
         raise ValueError("Unsupported model type")
@@ -380,7 +415,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
     # Mark whether cache has been initialized
     model._cache_setup_done = False
 
-    if compile:
+    # Disable compile if quantization is enabled (bitsandbytes INT8 is incompatible with torch.compile)
+    if compile and not quantize:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
@@ -388,6 +424,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
             mode="default" if torch.cuda.is_available() else None,
             fullgraph=True,
         )
+    elif compile and quantize:
+        logger.warning("torch.compile disabled when quantization is enabled (bitsandbytes compatibility)")
 
     return model.eval(), decode_one_token
 
@@ -775,6 +813,7 @@ def launch_thread_safe_queue(
     precision,
     compile: bool = False,
     num_workers: int = 1,
+    quantize: bool = False,
 ):
     input_queue = queue.Queue()
     init_events = [threading.Event() for _ in range(num_workers)]
@@ -782,7 +821,7 @@ def launch_thread_safe_queue(
     def worker(worker_id, init_event):
         logger.info(f"Worker {worker_id} starting, loading model...")
         model, decode_one_token = init_model(
-            checkpoint_path, device, precision, compile=compile
+            checkpoint_path, device, precision, compile=compile, quantize=quantize
         )
         with torch.device(device):
             model.setup_caches(
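
Since init_model now raises when bitsandbytes is missing, a quick pre-flight check before enabling quantization avoids a failed worker start. A sketch (the quoted log lines are the ones added in this diff):

# Fail fast instead of hitting the ImportError branch inside init_model
python -c "import bitsandbytes; print('bitsandbytes', bitsandbytes.__version__)"

# With --quantize active, each worker should log
#   "Applying INT8 quantization with bitsandbytes..."
#   "INT8 quantization applied successfully"
# during startup, and torch.compile is skipped with a warning if it was requested.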

pyproject.toml (+3 -0)

@@ -70,6 +70,9 @@ cu129 = [
   "torch==2.8.0",
   "torchaudio",
 ]
+quantization = [
+  "bitsandbytes>=0.41.0",
+]
 
 [tool.uv]
 override-dependencies = [
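
The new optional-dependency group can be installed ahead of time. The direct pip line matches the version pin added above; the extras form is an assumption about how this pyproject is consumed (local editable install shown):

# Straight from the requirement added above
pip install "bitsandbytes>=0.41.0"

# Or via the new extras group, assuming a local editable install
pip install -e ".[quantization]"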

tools/api_server.py (+1 -0)

@@ -97,6 +97,7 @@ class API(ExceptionHandler):
             decoder_checkpoint_path=self.args.decoder_checkpoint_path,
             decoder_config_name=self.args.decoder_config_name,
             num_workers=self.args.num_workers,
+            quantize=self.args.quantize,
         )
 
         logger.info(f"Startup done, listening server at http://{self.args.listen}")

tools/server/api_utils.py (+1 -0)

@@ -39,6 +39,7 @@ def parse_args():
     parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)
     parser.add_argument("--num-workers", type=int, default=1, help="Number of model worker threads for parallel inference")
+    parser.add_argument("--quantize", action="store_true", help="Enable INT8 quantization to reduce VRAM usage")
     parser.add_argument("--api-key", type=str, default=None)
 
     return parser.parse_args()
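
Combining the new flag with the existing server options, a direct invocation might look like this (checkpoint paths copied from deploy_multi_worker.sh; 3 workers per the commit title):

python tools/api_server.py \
  --listen 0.0.0.0:8080 \
  --llama-checkpoint-path checkpoints/s2-pro \
  --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \
  --workers 1 \
  --num-workers 3 \
  --quantize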

tools/server/model_manager.py (+3 -0)

@@ -19,6 +19,7 @@ class ModelManager:
         decoder_checkpoint_path: str,
         decoder_config_name: str,
         num_workers: int = 1,
+        quantize: bool = False,
     ) -> None:
 
         self.mode = mode
@@ -26,6 +27,7 @@ class ModelManager:
         self.half = half
         self.compile = compile
         self.num_workers = num_workers
+        self.quantize = quantize
 
         self.precision = torch.half if half else torch.bfloat16
 
@@ -66,6 +68,7 @@ class ModelManager:
                 precision=precision,
                 compile=compile,
                 num_workers=self.num_workers,
+                quantize=self.quantize,
             )
         else:
             raise ValueError(f"Invalid mode: {mode}")