@@ -40,10 +40,9 @@ from fish_speech.models.text2semantic.llama import (
 )
 
 
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
+def multinomial_sample_one_no_sync(probs_sort):
+    q = torch.rand_like(probs_sort)
+    q = -torch.log(q)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 
@@ -56,19 +55,22 @@ def logits_to_probs(
     logits,
     temperature: torch.Tensor,
     top_p: torch.Tensor,
-    top_k: torch.Tensor,
+    top_k: int,  # 注意: 我看到你传进来的是 int,这很关键
 ) -> torch.Tensor:
-    # Sort and compute top-p mask
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
     cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-    sorted_indices_to_remove = cum_probs > top_p
-    # top-k mask
-    sorted_indices_to_remove[top_k:] = True
-    sorted_indices_to_remove[0] = False  # keep at least one option
+
+    indices = torch.arange(sorted_logits.shape[-1], device=sorted_logits.device)
+    top_k_mask = indices >= top_k
+    sorted_indices_to_remove = (cum_probs > top_p) | top_k_mask
+    sorted_indices_to_remove[0] = False  # 单元素修改问题不大,或者写成 | (indices != 0)
+
     indices_to_remove = sorted_indices_to_remove.scatter(
         dim=-1, index=sorted_indices, src=sorted_indices_to_remove
     )
-    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+    logits = torch.where(
+        indices_to_remove, float("-Inf"), logits
+    )  # 同样替换 masked_fill_ 为 torch.where
     logits = logits / torch.clip(temperature, min=1e-5)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -143,19 +145,12 @@ def decode_one_token_ar(
 
     codebooks = [main_token_normal]
 
-    # Only clear cache for fast_layers, avoid clearing main model cache
-    for layer in model.fast_layers:
-        if hasattr(layer, "attention") and hasattr(layer.attention, "kv_cache"):
-            layer.attention.kv_cache.k_cache.fill_(0)
-            layer.attention.kv_cache.v_cache.fill_(0)
-
     input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
     model.forward_generate_fast(hidden_states, input_pos)
 
-    # [MODIFIED] Access config instead of tokenizer
     a = codebooks[0] - model.config.semantic_begin_id
-    a[a < 0] = 0
-    a[a >= model.config.codebook_size] = 0
+    a = torch.clamp(a, min=0, max=model.config.codebook_size - 1)
+
     hidden_states = model.fast_embeddings(a)
     codebooks.append(a)
 
@@ -390,7 +385,7 @@ def init_model(checkpoint_path, device, precision, compile=False):
         decode_one_token = torch.compile(
             decode_one_token,
             backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="reduce-overhead" if torch.cuda.is_available() else None,
+            mode="default" if torch.cuda.is_available() else None,
             fullgraph=True,
         )
 