Kaynağa Gözat

feat:开启compile

zhaohaipeng 1 ay önce
ebeveyn
işleme
1eeca7de1d
1 değiştirilmiş dosya ile 10 ekleme ve 3 silme
  1. +10 -3
      fish_speech/models/text2semantic/inference.py

+ 10 - 3
fish_speech/models/text2semantic/inference.py

@@ -393,7 +393,7 @@ def generate(
 def init_model(checkpoint_path, device, precision, compile=False):
 
     torch.backends.cuda.enable_flash_sdp(False)
-    torch.backends.cuda.enable_math_sdp(False)
+    torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
     torch.backends.cuda.enable_cudnn_sdp(True)
 
@@ -421,11 +421,18 @@ def init_model(checkpoint_path, device, precision, compile=False):
 
     if compile:
         logger.info("Compiling function...")
+        # decode_one_token = torch.compile(
+        #     decode_one_token,
+        #     backend="inductor" if torch.cuda.is_available() else "aot_eager",
+        #     mode="default" if torch.cuda.is_available() else None,
+        #     fullgraph=True,
+        # )
+
         decode_one_token = torch.compile(
             decode_one_token,
             backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="default" if torch.cuda.is_available() else None,
-            fullgraph=True,
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
+            fullgraph=False,
         )
 
     return model.eval(), decode_one_token