|
|
@@ -13,6 +13,7 @@ import torch
|
|
|
import torch._dynamo.config
|
|
|
import torch._inductor.config
|
|
|
from loguru import logger
|
|
|
+from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
from tqdm import tqdm
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
@@ -23,7 +24,12 @@ from fish_speech.conversation import (
|
|
|
TextPart,
|
|
|
VQPart,
|
|
|
)
|
|
|
-from fish_speech.models.text2semantic.llama import BaseModelArgs
|
|
|
+from fish_speech.models.text2semantic.llama import (
|
|
|
+ BaseModelArgs,
|
|
|
+ BaseTransformer,
|
|
|
+ DualARTransformer,
|
|
|
+ NaiveTransformer,
|
|
|
+)
|
|
|
from fish_speech.text import clean_text, split_text
|
|
|
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
|
|
|
|
|
|
@@ -36,15 +42,6 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
|
torch._inductor.config.fx_graph_cache = True
|
|
|
|
|
|
|
|
|
-from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
-
|
|
|
-from fish_speech.models.text2semantic.llama import (
|
|
|
- BaseTransformer,
|
|
|
- DualARTransformer,
|
|
|
- NaiveTransformer,
|
|
|
-)
|
|
|
-
|
|
|
-
|
|
|
def multinomial_sample_one_no_sync(
|
|
|
probs_sort,
|
|
|
): # Does multinomial sampling without a cuda synchronization
|
|
|
@@ -372,8 +369,12 @@ def decode_n_tokens(
|
|
|
window = previous_tokens[:, i - win_size : i]
|
|
|
|
|
|
with (
|
|
|
- torch.backends.cuda.sdp_kernel(
|
|
|
- enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
|
+ sdpa_kernel(
|
|
|
+ [
|
|
|
+ SDPBackend.FLASH_ATTENTION,
|
|
|
+ SDPBackend.EFFICIENT_ATTENTION,
|
|
|
+ SDPBackend.MATH,
|
|
|
+ ]
|
|
|
)
|
|
|
if torch.cuda.is_available()
|
|
|
else nullcontext()
|