@@ -357,7 +357,7 @@ def generate(
     return seq


-def init_model(checkpoint_path, device, precision, compile=False):
+def init_model(checkpoint_path, device, precision, compile=False, quantize=False):
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)

     logger.info(f"precision: {precision.__class__.__name__}")
@@ -365,9 +365,44 @@ def init_model(checkpoint_path, device, precision, compile=False):
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")

+    # Apply INT8 quantization if requested
+    if quantize:
+        try:
+            import bitsandbytes as bnb
+            logger.info("Applying INT8 quantization with bitsandbytes...")
+
+            # Replace all Linear layers with 8-bit quantized versions
+            def replace_linear_with_int8(module):
+                for name, child in module.named_children():
+                    if isinstance(child, torch.nn.Linear):
+                        # Create 8-bit linear layer
+                        int8_layer = bnb.nn.Linear8bitLt(
+                            child.in_features,
+                            child.out_features,
+                            bias=child.bias is not None,
+                            has_fp16_weights=False,
+                            threshold=6.0,
+                        )
+                        # Copy weights
+                        int8_layer.weight = bnb.nn.Int8Params(
+                            child.weight.data,
+                            requires_grad=False,
+                            has_fp16_weights=False,
+                        )
+                        if child.bias is not None:
+                            int8_layer.bias = child.bias
+                        setattr(module, name, int8_layer)
+                    else:
+                        replace_linear_with_int8(child)
+
+            replace_linear_with_int8(model)
+            logger.info("INT8 quantization applied successfully")
+        except ImportError:
+            logger.error("bitsandbytes not installed. Install with: pip install bitsandbytes")
+            raise
+
     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
-        # prefill_n_tokens = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
         raise ValueError("Unsupported model type")
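
For reference, here is the same `Linear8bitLt` swap in isolation: a minimal sketch, assuming a CUDA device and bitsandbytes installed; the 4096x4096 layer size is illustrative. With `has_fp16_weights=False`, the actual INT8 conversion happens when the layer is moved to the GPU:

```python
# Standalone sketch of the swap performed by replace_linear_with_int8 above.
# Assumes CUDA + bitsandbytes; the 4096x4096 layer is illustrative only.
import torch
import bitsandbytes as bnb

fp16 = torch.nn.Linear(4096, 4096, bias=True).half()

int8 = bnb.nn.Linear8bitLt(
    fp16.in_features,
    fp16.out_features,
    bias=True,
    has_fp16_weights=False,  # keep weights in INT8 after quantization
    threshold=6.0,           # LLM.int8() outlier threshold
)
int8.weight = bnb.nn.Int8Params(
    fp16.weight.data, requires_grad=False, has_fp16_weights=False
)
int8.bias = fp16.bias

int8 = int8.cuda()  # quantization is triggered by the move to the GPU
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = int8(x)
print(y.dtype, int8.weight.dtype)  # torch.float16 torch.int8
```

One point worth checking during review: in the patch, `model` is moved to `device` before the swap, so `Int8Params` may be constructed from tensors that already live on the GPU, while bitsandbytes performs the quantization inside its `.cuda()`/`.to()` hook. It is worth asserting that `weight.dtype == torch.int8` after `init_model` returns.
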
@@ -380,7 +415,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
     # Mark whether cache has been initialized
     model._cache_setup_done = False

-    if compile:
+    # Disable compile if quantization is enabled (bitsandbytes INT8 is incompatible with torch.compile)
+    if compile and not quantize:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
@@ -388,6 +424,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
             mode="default" if torch.cuda.is_available() else None,
             fullgraph=True,
         )
+    elif compile and quantize:
+        logger.warning("torch.compile disabled when quantization is enabled (bitsandbytes compatibility)")

     return model.eval(), decode_one_token

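For context on the guard above: with `fullgraph=True`, `torch.compile` raises on any graph break instead of falling back to eager execution, and bitsandbytes' custom CUDA kernels are a well-known source of graph breaks, so skipping compilation entirely for quantized models is the conservative choice.
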
@@ -775,6 +813,7 @@ def launch_thread_safe_queue(
     precision,
     compile: bool = False,
     num_workers: int = 1,
+    quantize: bool = False,
 ):
     input_queue = queue.Queue()
     init_events = [threading.Event() for _ in range(num_workers)]
@@ -782,7 +821,7 @@ def launch_thread_safe_queue(
     def worker(worker_id, init_event):
         logger.info(f"Worker {worker_id} starting, loading model...")
         model, decode_one_token = init_model(
-            checkpoint_path, device, precision, compile=compile
+            checkpoint_path, device, precision, compile=compile, quantize=quantize
         )
         with torch.device(device):
             model.setup_caches(
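
Downstream, callers opt in through the new flag. A hypothetical invocation, assuming `checkpoint_path` and `device` are the parameters preceding `precision` (they are used in the worker body but sit outside this hunk's context):

```python
# Hypothetical caller; parameter names before `precision` are inferred
# from the worker body, since they fall outside the hunk context.
import torch

handle = launch_thread_safe_queue(  # return value is outside this hunk
    checkpoint_path="checkpoints/model",  # illustrative path
    device="cuda",
    precision=torch.bfloat16,
    compile=False,   # ignored anyway when quantize=True
    num_workers=1,
    quantize=True,   # enable the INT8 path added above
)
```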