@@ -14,6 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
 from fish_speech.text.clean import clean_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -147,9 +148,9 @@ def decode_one_token_naive(
         codebooks.append(
             sample(
                 x.codebook_logits[:, :, i],
-                previous_tokens=previous_tokens[i + 1]
-                if previous_tokens is not None
-                else None,
+                previous_tokens=(
+                    previous_tokens[i + 1] if previous_tokens is not None else None
+                ),
                 **sampling_kwargs,
             )[0]
         )
@@ -163,6 +164,7 @@ def decode_n_tokens(
     input_pos: torch.Tensor,
     num_new_tokens: int,
     eos_token_id: int = 2,
+    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ):
@@ -197,8 +199,11 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )
 
-        # TODO: use tokenizer's eos
-        if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
+        if (
+            cur_token[0, 0, -1] == eos_token_id
+            or cur_token[0, 0, -1] == im_end_id
+            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
+        ):
             break
 
     return previous_tokens[:, : i + 1]
@@ -212,6 +217,7 @@ def generate(
     prompt: torch.Tensor,
     max_new_tokens: int,
     eos_token_id: int = 2,
+    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
@@ -256,6 +262,7 @@ def generate(
         input_pos,
         max_new_tokens - 1,
         eos_token_id=eos_token_id,
+        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
         **sampling_kwargs,
     )
@@ -283,9 +290,12 @@ def encode_tokens(
     string = (
         f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
     )
+    if bos:
+        string = f"<|begin_of_sequence|>{string}"
+
     new_tokens = tokenizer.encode(
         string,
-        add_special_tokens=bos,
+        add_special_tokens=False,
         max_length=10**6,
         truncation=False,
     )
@@ -392,7 +402,11 @@ def split_text(text, min_length):
 
 
 @click.command()
-@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
 @click.option("--prompt-text", type=str, default=None)
 @click.option(
     "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
@@ -457,6 +471,8 @@ def main(
         else None
     )
 
+    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
@@ -539,6 +555,7 @@ def main(
             prompt=cat_encoded,
             max_new_tokens=max_new_tokens,
             eos_token_id=tokenizer.eos_token_id,
+            im_end_id=im_end_id,
             decode_one_token=decode_one_token,
             precision=precision,
             temperature=temperature,
@@ -575,8 +592,15 @@ def main(
             logger.warning(f"Negative code found: {codes}, retrying ...")
             continue
 
+        decoded = y[:, prompt_length:-1].clone()
+        if decoded[0, -1] != im_end_id:  # <im_end>
+            val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
+            decoded = torch.cat(
+                (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
+            )
+
         # But for global encoding, we should keep the <im_end> token
-        global_encoded.append(y[:, prompt_length:-1].clone())
+        global_encoded.append(decoded)
         all_codes.append(codes)
         seg_idx += 1
 
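
Read together, the hunks thread a second stop token, im_end_id, from main (where it is looked up via tokenizer.convert_tokens_to_ids("<|im_end|>")) through generate into the stopping check in decode_n_tokens. Below is a minimal sketch of that check in isolation, not part of the patch; the tensor layout, the dummy step tensor, and the CODEBOOK_EOS_TOKEN_ID value here are illustrative assumptions, not the repository's definitions.

import torch

CODEBOOK_EOS_TOKEN_ID = 1  # assumed placeholder; the real id comes from fish_speech.datasets.text

def should_stop(cur_token: torch.Tensor, eos_token_id: int, im_end_id: int) -> bool:
    # cur_token: shape (1, num_codebooks + 1, T); row 0 holds the text-token
    # stream, rows 1.. hold codebook streams, matching the hunk's indexing.
    return bool(
        cur_token[0, 0, -1] == eos_token_id
        or cur_token[0, 0, -1] == im_end_id
        or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
    )

# A step whose text row emits <|im_end|> (id 4 here) now stops decoding,
# exactly as a plain EOS token would.
step = torch.tensor([[[4], [7], [9]]])
assert should_stop(step, eos_token_id=2, im_end_id=4)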