@@ -14,7 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text.clean import clean_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -291,7 +291,7 @@ def encode_tokens(
 ):
     string = clean_text(string)
 
-    if speaker is not None:
+    if speaker is None:
         speaker = "assistant"
 
     string = (
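The inverted check fixes a real bug: with `if speaker is not None`, any caller-supplied speaker was unconditionally overwritten with "assistant"; the default should apply only when no speaker is given. A minimal sketch of the corrected behaviour (the helper name is hypothetical, not part of the diff):

# Hypothetical helper mirroring the corrected default logic.
def resolve_speaker(speaker):
    if speaker is None:  # fill in the default only when nothing was passed
        speaker = "assistant"
    return speaker

assert resolve_speaker(None) == "assistant"
assert resolve_speaker("narrator") == "narrator"  # no longer clobbered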
@@ -309,7 +309,10 @@ def encode_tokens(
|
|
|
tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
|
|
|
|
|
|
# Codebooks
|
|
|
- zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
|
|
|
+ zeros = (
|
|
|
+ torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
|
|
|
+ * CODEBOOK_PAD_TOKEN_ID
|
|
|
+ )
|
|
|
prompt = torch.cat((tokens, zeros), dim=0)
|
|
|
|
|
|
if prompt_tokens is None:
|
|
|
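Together with the new import above, the codebook rows sitting under pure text tokens are now filled with the dedicated pad id instead of a hard-coded 0, so padding stays distinguishable from whatever codebook entry happens to have id 0. A shape sketch under assumed toy values (the real constant lives in fish_speech.datasets.text):

import torch

CODEBOOK_PAD_TOKEN_ID = 0  # placeholder value for this sketch only
num_codebooks, seq_len = 4, 6

# ones * constant matches the diff's idiom; torch.full would be equivalent.
pad_rows = (
    torch.ones((num_codebooks, seq_len), dtype=torch.int)
    * CODEBOOK_PAD_TOKEN_ID
)
print(pad_rows.shape)  # torch.Size([4, 6]): one pad row per codebook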
@@ -331,13 +334,23 @@ def encode_tokens(
|
|
|
)
|
|
|
data = data[:num_codebooks]
|
|
|
|
|
|
+ # Add eos token for each codebook
|
|
|
+ data = torch.cat(
|
|
|
+ (
|
|
|
+ data,
|
|
|
+ torch.ones((data.size(0), 1), dtype=torch.int, device=device)
|
|
|
+ * CODEBOOK_EOS_TOKEN_ID,
|
|
|
+ ),
|
|
|
+ dim=1,
|
|
|
+ )
|
|
|
+
|
|
|
# Since 1.0, we use <|semantic|>
|
|
|
s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
|
|
|
- main_token_ids = torch.tensor(
|
|
|
- [[s0_token_id] * data.size(1)],
|
|
|
- dtype=torch.int,
|
|
|
- device=device,
|
|
|
+ end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
|
|
+ main_token_ids = (
|
|
|
+ torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
|
|
|
)
|
|
|
+ main_token_ids[0, -1] = end_token_id
|
|
|
|
|
|
data = torch.cat((main_token_ids, data), dim=0)
|
|
|
prompt = torch.cat((prompt, data), dim=1)
|
|
|
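This hunk adds two terminators: every codebook row gets a trailing CODEBOOK_EOS_TOKEN_ID column, and the main text row, otherwise filled with the <|semantic|> id, carries <|im_end|> in that final position. A shape sketch with placeholder ids (the real ids come from the tokenizer and fish_speech.datasets.text):

import torch

CODEBOOK_EOS_TOKEN_ID = 1          # placeholder for this sketch
s0_token_id, end_token_id = 5, 7   # stand-ins for <|semantic|> / <|im_end|>

data = torch.zeros((4, 3), dtype=torch.int)               # 4 codebooks, 3 audio frames
eos_col = torch.ones((4, 1), dtype=torch.int) * CODEBOOK_EOS_TOKEN_ID
data = torch.cat((data, eos_col), dim=1)                   # (4, 4): EOS appended per codebook

main = torch.ones((1, data.size(1)), dtype=torch.int) * s0_token_id
main[0, -1] = end_token_id                                 # <|im_end|> sits above the EOS column
stacked = torch.cat((main, data), dim=0)                   # (5, 4): text row + codebook rows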
@@ -450,6 +463,20 @@ def generate_long(
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+
+    if use_prompt:
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                prompt_text,
+                prompt_tokens=prompt_tokens,
+                bos=True,
+                device=device,
+                speaker=speaker,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -457,25 +484,12 @@
                 string=text,
                 bos=idx == 0 and not use_prompt,
                 device=device,
-                speaker=None,
+                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
             )
         )
         logger.info(f"Encoded text: {text}")
 
-    if use_prompt:
-        encoded_prompt = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
-
-        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
-
     for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []
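The last two hunks are one logical change: instead of encoding the text chunks first and splicing the encoded prompt onto encoded[0] with torch.cat afterwards, generate_long now encodes the prompt up front and appends it as its own entry, and the chunk encodings inherit the caller's speaker rather than None. The bos flag keeps its semantics: it is set exactly once, on the prompt when one exists, otherwise on the first chunk. A condensed ordering sketch, where encode() is a stand-in for encode_tokens():

# Condensed ordering sketch; encode() stands in for encode_tokens().
def encode(text, bos):
    return ("<bos>" if bos else "") + text

def build_encoded(texts, prompt_text=None, prompt_tokens=None):
    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    if use_prompt:
        # The prompt is encoded first and leads the list...
        encoded.append(encode(prompt_text, bos=True))
    for idx, text in enumerate(texts):
        # ...so a chunk carries <bos> only when there is no prompt.
        encoded.append(encode(text, bos=idx == 0 and not use_prompt))
    return encoded

assert build_encoded(["a", "b"]) == ["<bos>a", "b"]
assert build_encoded(["a", "b"], "ref", [0]) == ["<bos>ref", "a", "b"]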