|
|
@@ -7,10 +7,11 @@ import traceback
|
|
|
from copy import deepcopy
|
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
|
-from typing import Callable, Literal, Optional, Tuple, Union, Any
|
|
|
+from typing import Callable, Literal, Optional, Tuple, Union
|
|
|
|
|
|
import click
|
|
|
import numpy as np
|
|
|
+import torch
|
|
|
import torch._inductor.config
|
|
|
from loguru import logger
|
|
|
from tqdm import tqdm
|
|
|
@@ -29,10 +30,13 @@ torch._inductor.config.triton.unique_kernel_names = True
|
|
|
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
|
torch._inductor.config.fx_graph_cache = True
|
|
|
|
|
|
+
|
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
|
|
|
|
from fish_speech.models.text2semantic.llama import (
|
|
|
+ BaseTransformer,
|
|
|
DualARTransformer,
|
|
|
+ NaiveTransformer,
|
|
|
)
|
|
|
|
|
|
|
|
|
@@ -245,16 +249,15 @@ def decode_n_tokens(
|
|
|
@torch.no_grad()
|
|
|
@torch.inference_mode()
|
|
|
def generate(
|
|
|
- *,
|
|
|
- model: DualARTransformer,
|
|
|
- prompt: torch.Tensor,
|
|
|
- max_new_tokens: int,
|
|
|
- audio_masks: torch.Tensor,
|
|
|
- audio_parts: torch.Tensor,
|
|
|
- prompt_tokens = None,
|
|
|
- decode_one_token=decode_one_token_ar,
|
|
|
- num_samples: int = 1,
|
|
|
- **sampling_kwargs,
|
|
|
+ *,
|
|
|
+ model: DualARTransformer,
|
|
|
+ prompt: torch.Tensor,
|
|
|
+ max_new_tokens: int,
|
|
|
+ audio_masks: torch.Tensor,
|
|
|
+ audio_parts: torch.Tensor,
|
|
|
+ decode_one_token=decode_one_token_ar,
|
|
|
+ num_samples: int = 1,
|
|
|
+ **sampling_kwargs,
|
|
|
):
|
|
|
"""
|
|
|
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
|
|
@@ -337,6 +340,7 @@ def generate(
|
|
|
step7 = time.perf_counter()
|
|
|
|
|
|
prefill_decode = decode_one_token_ar
|
|
|
+
|
|
|
first_token = prefill_decode(
|
|
|
model,
|
|
|
prompt.view(1, codebook_dim, -1),
|
|
|
@@ -355,32 +359,6 @@ def generate(
|
|
|
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
|
|
step9 = time.perf_counter()
|
|
|
|
|
|
- im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
- codebook_dim = 1 + model.config.num_codebooks
|
|
|
- window_size = 64
|
|
|
-
|
|
|
- previous_tokens = torch.zeros(
|
|
|
- (1, window_size, codebook_dim),
|
|
|
- device=device,
|
|
|
- dtype=first_token.dtype,
|
|
|
- )
|
|
|
-
|
|
|
- # =========================
|
|
|
- # 1. warm start prompt
|
|
|
- # =========================
|
|
|
- if prompt_tokens is not None:
|
|
|
- # 确保 shape = [B, T, C]
|
|
|
- if prompt_tokens.dim() == 2:
|
|
|
- prompt_tokens = prompt_tokens.unsqueeze(0)
|
|
|
-
|
|
|
- T = min(prompt_tokens.size(1), window_size)
|
|
|
- previous_tokens[:, -T:] = prompt_tokens[:, -T:]
|
|
|
-
|
|
|
- # =========================
|
|
|
- # 2. insert first token
|
|
|
- # =========================
|
|
|
- previous_tokens[:, -1, :] = first_token.view(codebook_dim)
|
|
|
-
|
|
|
x = decode_n_tokens(
|
|
|
model,
|
|
|
first_token.view(1, codebook_dim, -1),
|
|
|
@@ -627,8 +605,6 @@ def generate_long(
|
|
|
# Build base conversation with system message
|
|
|
base_conversation = Conversation()
|
|
|
|
|
|
- all_codes = None
|
|
|
-
|
|
|
if use_prompt:
|
|
|
# Auto-add speaker tags to prompt texts that don't have them
|
|
|
tagged_prompt_text = []
|
|
|
@@ -750,8 +726,6 @@ def generate_long(
|
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
|
decode_one_token=decode_one_token,
|
|
|
- prompt_tokens=all_codes,
|
|
|
-
|
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
|
top_k=top_k,
|