|
@@ -362,6 +362,8 @@ def generate(
|
|
|
def init_model(checkpoint_path, device, precision, compile=False):
|
|
def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
|
|
model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
|
|
|
|
|
|
|
|
|
|
+ logger.info(f"precision: {precision.__class__.__name__}")
|
|
|
|
|
+
|
|
|
model = model.to(device=device, dtype=precision)
|
|
model = model.to(device=device, dtype=precision)
|
|
|
logger.info(f"Restored model from checkpoint")
|
|
logger.info(f"Restored model from checkpoint")
|
|
|
|
|
|