Przeglądaj źródła

feat:修改mode

zhaohaipeng 1 miesiąc temu
rodzic
commit
6f2a4f76c0
1 zmienionych plików z 4 dodań i 4 usunięć
  1. 4 4
      fish_speech/models/text2semantic/inference.py

+ 4 - 4
fish_speech/models/text2semantic/inference.py

@@ -93,7 +93,7 @@ def sample(
     return idx_next, probs
 
 
-def decode_one_token_ar_old(
+def decode_one_token_ar(
         model: DualARTransformer,
         x: torch.Tensor,
         input_pos: torch.Tensor,
@@ -181,7 +181,7 @@ def decode_one_token_ar_old(
     return codebooks.T
 
 @torch.inference_mode()
-def decode_one_token_ar(
+def decode_one_token_ar_optimize(
     model,
     x,
     input_pos,
@@ -522,8 +522,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
         decode_one_token = torch.compile(
             decode_one_token,
             backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="default" if torch.cuda.is_available() else None,
-            fullgraph=True,
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
+            fullgraph=False,
         )
 
     return model.eval(), decode_one_token