@@ -17,9 +17,16 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer

-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -145,8 +152,8 @@ def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     # print(x, input_pos)
@@ -190,19 +197,13 @@ def decode_one_token_ar_agent(
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

-    # for i in range(codebooks.size(1) - 1):
-    #     codebooks[:, i + 1, :] = torch.masked_fill(
-    #         codebooks[:, i + 1, :],
-    #         codebooks[:, :1, :] != semantic_id,
-    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
-    #     )
-
-    # print(codebooks)
-
     return codebooks

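Note on the masking change in this hunk: v1.5 replaces the single `<|semantic|>` token with 1024 per-code `<|semantic:i|>` tokens, so the residual-codebook pad mask must test set membership (`torch.isin`) rather than inequality against one id. A minimal runnable sketch, with hypothetical token ids and `CODEBOOK_PAD_TOKEN_ID` assumed to be 0:

    import torch

    CODEBOOK_PAD_TOKEN_ID = 0         # assumed value for this sketch
    semantic_ids = [101, 102, 103]    # hypothetical <|semantic:i|> ids
    sem = torch.tensor(semantic_ids)

    # Shape (batch=1, 1 + num_codebooks, seq_len); row 0 is the main stream.
    codebooks = torch.tensor([[[101, 7, 103],   # semantic, text, semantic
                               [4, 5, 6]]])     # residual codebook row

    # Positions whose main token is not a semantic token get padded.
    mask = ~torch.isin(codebooks[:, :1, :], sem)
    codebooks[:, 1:, :] = codebooks[:, 1:, :].masked_fill(mask, CODEBOOK_PAD_TOKEN_ID)
    print(codebooks[:, 1:, :])  # tensor([[[4, 0, 6]]])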
@@ -210,8 +211,8 @@ def decode_one_token_naive_agent(
     model: NaiveTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -236,8 +237,11 @@ def decode_one_token_naive_agent(
         )

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

     return codebooks
@@ -247,8 +251,8 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -261,21 +265,32 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]

-    x = x.hidden_states
+    hidden_states = x.hidden_states

     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)

-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -285,14 +300,16 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-        x = model.fast_embeddings(a)
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=0)
-    codebooks[1:, :] = torch.masked_fill(
-        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
-    )
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )

+    # print(codebooks)
     return codebooks

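Note on the fast-layer priming added above: codebook 0 is no longer produced by a blind first pass; the sampled main-stream token is mapped into codebook-index space by subtracting `model.tokenizer.semantic_begin_id` (presumably the id of `<|semantic:0|>`), with non-semantic tokens clamped to index 0. A toy sketch with a hypothetical begin id:

    import torch

    semantic_begin_id = 100                     # hypothetical <|semantic:0|> id
    main_tokens = torch.tensor([98, 100, 135])  # sampled main-stream token ids

    a = main_tokens - semantic_begin_id         # map into codebook-index space
    a[a < 0] = 0                                # clamp non-semantic tokens, as in the hunk
    print(a)  # tensor([ 0,  0, 35])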
@@ -337,9 +354,8 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    im_end_id: int = 4,
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -368,7 +384,7 @@ def decode_n_tokens(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-            semantic_id=semantic_id,
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )

@@ -378,7 +394,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )

-        if cur_token[0, 0, -1] == im_end_id:
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break

     return previous_tokens[:, : i + 1]
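Aside on the new stop check: `model.tokenizer.get_token_id(IM_END_TOKEN)` is resolved inside the decode loop, once per generated step. Behavior would be identical with the id hoisted out of the loop; a self-contained sketch (hypothetical vocabulary, not the PR's code):

    def find_stop(tokens, get_token_id):
        im_end_id = get_token_id("<|im_end|>")  # looked up once, not per step
        for i, t in enumerate(tokens):
            if t == im_end_id:
                return i
        return len(tokens)

    print(find_stop([7, 8, 4, 9], {"<|im_end|>": 4}.get))  # -> 2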
@@ -391,7 +407,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -401,7 +416,10 @@ def generate(

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]

     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
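Note on the `semantic_ids` list above: the 1024 `<|semantic:i|>` ids are collected one by one via `FishTokenizer.get_token_id`. A toy stand-in, assuming the ids are contiguous (which is also what the `semantic_begin_id` arithmetic in `decode_one_token_ar` relies on):

    SEMANTIC_BEGIN_ID = 100  # hypothetical id of <|semantic:0|>

    def get_token_id(token: str) -> int:
        assert token.startswith("<|semantic:") and token.endswith("|>")
        return SEMANTIC_BEGIN_ID + int(token[len("<|semantic:") : -2])

    semantic_ids = [get_token_id(f"<|semantic:{i}|>") for i in range(1024)]
    print(semantic_ids[0], semantic_ids[-1])  # 100 1123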
@@ -435,7 +453,7 @@ def generate(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
@@ -446,9 +464,8 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -463,8 +480,8 @@ def decode_n_tokens_agent(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     early_stop_threshold: float = 0.6,
     **sampling_kwargs,
@@ -495,7 +512,7 @@ def decode_n_tokens_agent(
             x=cur_token,
             input_pos=input_pos,
             previous_tokens=window,
-            semantic_id=semantic_id,
+            semantic_ids=semantic_ids,
             **sampling_kwargs,
         )

@@ -529,8 +546,8 @@ def generate_agent(
     model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     num_samples: int = 1,
     early_stop_threshold: float = 0.6,
@@ -574,7 +591,7 @@ def generate_agent(
         model,
         prompt,
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     ).view(num_samples, codebook_dim, -1)
     yield next_token.cpu()
@@ -587,7 +604,7 @@ def generate_agent(
         input_pos,
         max_new_tokens - 1,
         im_end_id=im_end_id,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         decode_one_token=decode_one_token,
         early_stop_threshold=early_stop_threshold,
         **sampling_kwargs,
@@ -602,65 +619,63 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

-    new_tokens = tokenizer.encode(
-        string,
-        add_special_tokens=False,
-        max_length=10**6,
-        truncation=False,
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

-    # Codebooks
-    zeros = (
-        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
-        * CODEBOOK_PAD_TOKEN_ID
-    )
-    prompt = torch.cat((tokens, zeros), dim=0)
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]

-    if prompt_tokens is None:
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

-    # Get prompt tokens
-    if prompt_tokens.ndim == 3:
-        assert (
-            prompt_tokens.shape[0] == 1
-        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]

-    assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))

-    if prompt_tokens.shape[0] > num_codebooks:
-        logger.warning(
-            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-    # Add pad token for each codebook
-    data = torch.cat(
-        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-        dim=1,
-    )

-    # Since 1.0, we use <|semantic|>
-    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    main_token_ids = (
-        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)

-    return prompt
+    return encoded.to(device)


 def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )

     model = model.to(device=device, dtype=precision)
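The `encode_tokens` rewrite above hands prompt assembly to `Conversation.encode_for_inference`; the only tensor bookkeeping this function keeps is the prompt-token normalization. A self-contained restatement of just that part, with shapes as stated in the assertions:

    import torch

    def normalize_prompt_tokens(prompt_tokens: torch.Tensor, num_codebooks: int) -> torch.Tensor:
        # (1, num_codebooks, seq_len) -> (num_codebooks, seq_len)
        if prompt_tokens.ndim == 3:
            assert prompt_tokens.shape[0] == 1
            prompt_tokens = prompt_tokens[0]
        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
        # Drop extra codebooks beyond num_codebooks.
        if prompt_tokens.shape[0] > num_codebooks:
            prompt_tokens = prompt_tokens[:num_codebooks]
        return prompt_tokens

    print(normalize_prompt_tokens(torch.zeros(1, 8, 5, dtype=torch.long), 4).shape)
    # torch.Size([4, 5])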
@@ -729,11 +744,26 @@ def generate_long(

     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    im_end_id = tokenizer.get_token_id("<|im_end|>")

     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = []
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]

     if use_prompt:
         for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -812,7 +842,6 @@ def generate_long(
                 model=model,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
-                im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
                 top_p=top_p,
@@ -842,12 +871,11 @@ def generate_long(
             )

             # Put the generated tokens
-            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-1].clone()
-            codes = codes - 1
+            # since there is <im_end>, we remove last token
+            codes = y[1:, prompt_length + 1 :].clone()
             assert (codes >= 0).all(), f"Negative code found"

-            decoded = y[:, prompt_length:-1].clone()
+            decoded = y[:, prompt_length:].clone()
             # But for global encoding, we should keep the <im_end> token

             global_encoded.append(decoded)
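Bookkeeping note on the last hunk: the dropped `codes = codes - 1` matches the encode side, where the old path stored prompt codes shifted by `+ 1` (see the removed `data = prompt_tokens + 1` in `encode_tokens`); with the new tokenizer, codes round-trip unshifted. The extraction window also moves by one frame. A toy slice with made-up values:

    import torch

    prompt_length = 2
    y = torch.tensor([
        [9, 9, 7, 101, 4],  # row 0: main token stream
        [0, 0, 0, 3, 8],    # rows 1..n: codebook streams
    ])
    print(y[1:, prompt_length:-1])     # old slice -> tensor([[0, 3]])
    print(y[1:, prompt_length + 1:])   # new slice -> tensor([[3, 8]])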