|
|
@@ -235,7 +235,7 @@ def decode_n_tokens(
|
|
|
]
|
|
|
new_tokens.append(next_token)
|
|
|
f_end = time.perf_counter()
|
|
|
- logger.info(f"num_new_tokens for elapse: {f_end - f_start}")
|
|
|
+ # logger.info(f"num_new_tokens for elapse: {f_end - f_start}")
|
|
|
|
|
|
if cur_token[0, 0, -1] == im_end_id:
|
|
|
break
|
|
|
@@ -391,6 +391,12 @@ def generate(
|
|
|
|
|
|
|
|
|
def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
+
|
|
|
+ torch.backends.cuda.enable_flash_sdp(False)
|
|
|
+ torch.backends.cuda.enable_math_sdp(False)
|
|
|
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
|
+ torch.backends.cuda.enable_cudnn_sdp(True)
|
|
|
+
|
|
|
model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
|
|
|
|
|
|
logger.info(f"precision: {precision.__class__.__name__}")
|