Jelajahi Sumber

feat:开启compile

zhaohaipeng 1 bulan lalu
induk
melakukan
7462a337a4
2 mengubah file dengan 8 tambahan dan 2 penghapusan
  1. 1 1
      .env
  2. 7 1
      fish_speech/models/text2semantic/inference.py

+ 1 - 1
.env

@@ -1,4 +1,4 @@
 API_PORT=8080
-COMPILE=0
+COMPILE=1
 HALF=1
 PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:64,expandable_segments:True

+ 7 - 1
fish_speech/models/text2semantic/inference.py

@@ -235,7 +235,7 @@ def decode_n_tokens(
         ]
         new_tokens.append(next_token)
         f_end = time.perf_counter()
-        logger.info(f"num_new_tokens for elapse: {f_end - f_start}")
+        # logger.info(f"num_new_tokens for elapse: {f_end - f_start}")
 
         if cur_token[0, 0, -1] == im_end_id:
             break
@@ -391,6 +391,12 @@ def generate(
 
 
 def init_model(checkpoint_path, device, precision, compile=False):
+
+    torch.backends.cuda.enable_flash_sdp(False)
+    torch.backends.cuda.enable_math_sdp(False)
+    torch.backends.cuda.enable_mem_efficient_sdp(True)
+    torch.backends.cuda.enable_cudnn_sdp(True)
+
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     logger.info(f"precision: {precision.__class__.__name__}")