|
|
@@ -13,7 +13,6 @@ import torch
|
|
|
import torch._dynamo.config
|
|
|
import torch._inductor.config
|
|
|
from loguru import logger
|
|
|
-from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
from tqdm import tqdm
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
@@ -24,12 +23,7 @@ from fish_speech.conversation import (
|
|
|
TextPart,
|
|
|
VQPart,
|
|
|
)
|
|
|
-from fish_speech.models.text2semantic.llama import (
|
|
|
- BaseModelArgs,
|
|
|
- BaseTransformer,
|
|
|
- DualARTransformer,
|
|
|
- NaiveTransformer,
|
|
|
-)
|
|
|
+from fish_speech.models.text2semantic.llama import BaseModelArgs
|
|
|
from fish_speech.text import clean_text, split_text
|
|
|
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
|
|
|
|
|
|
@@ -42,6 +36,15 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
|
torch._inductor.config.fx_graph_cache = True
|
|
|
|
|
|
|
|
|
+from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
+
|
|
|
+from fish_speech.models.text2semantic.llama import (
|
|
|
+ BaseTransformer,
|
|
|
+ DualARTransformer,
|
|
|
+ NaiveTransformer,
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
def multinomial_sample_one_no_sync(
|
|
|
probs_sort,
|
|
|
): # Does multinomial sampling without a cuda synchronization
|
|
|
@@ -369,12 +372,8 @@ def decode_n_tokens(
|
|
|
window = previous_tokens[:, i - win_size : i]
|
|
|
|
|
|
with (
|
|
|
- sdpa_kernel(
|
|
|
- [
|
|
|
- SDPBackend.FLASH_ATTENTION,
|
|
|
- SDPBackend.EFFICIENT_ATTENTION,
|
|
|
- SDPBackend.MATH,
|
|
|
- ]
|
|
|
+ torch.backends.cuda.sdp_kernel(
|
|
|
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
|
)
|
|
|
if torch.cuda.is_available()
|
|
|
else nullcontext()
|