@@ -27,7 +27,7 @@ from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
-CODEBOOK_BOS_TOKEN_ID = 0
+CODEBOOK_PAD_TOKEN_ID = 0
 CODEBOOK_EOS_TOKEN_ID = 1
 
 
@@ -476,11 +476,6 @@ class AutoAugTextDataset(IterableDataset):
         sentences = [f"[SPK: {speaker}]"] + sentences
 
         final_text = "[INST] " + " ".join(sentences) + " [/INST]"
-
-        for segment in semantics:
-            for j in segment[0].values:
-                final_text += f" <s:{int(j)}>"
-
         encoded = self.tokenizer.encode(
             final_text,
             add_special_tokens=False,
@@ -488,14 +483,15 @@ class AutoAugTextDataset(IterableDataset):
             max_length=10**6,
         )
         semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded) - semantic_length
+        prompt_length = len(encoded)
+        num_codebooks = len(semantics[0])
 
         bos_bias = 1 if add_bos else 0
 
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
         tokens = (
             encoded
-            # + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.pad_token_id] * (semantic_length + num_codebooks - 1)
             + [self.tokenizer.eos_token_id]
         )
 
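Aside on the width arithmetic above: since the semantic ids are no longer inlined into final_text, prompt_length is now just len(encoded), and the text row is padded so that every row of the packed matrix comes out the same width once the per-codebook delay of the next hunk is applied. A minimal check with toy values; the names mirror the diff, the numbers are illustrative, and bos_bias is omitted for brevity:

    prompt_length = 4       # len(encoded); semantic ids no longer inlined
    semantic_length = 6     # total codes per codebook across all segments
    num_codebooks = 2

    # Text row: prompt, pads covering the codes plus the delay tail, then eos.
    text_row = prompt_length + (semantic_length + num_codebooks - 1) + 1

    # Codebook row i: left delay of i, the codes, right pad of (n - i - 1), eos.
    code_rows = [
        (prompt_length + i) + semantic_length + (num_codebooks - i - 1) + 1
        for i in range(num_codebooks)
    ]
    assert all(width == text_row for width in code_rows)  # all rows align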
@@ -503,16 +499,18 @@ class AutoAugTextDataset(IterableDataset):
             tokens = [self.tokenizer.bos_token_id] + tokens
 
         # Codebook bos/padding: 0, eos: 1
+        # Implement delay pattern
         codes = [
-            [CODEBOOK_BOS_TOKEN_ID] * (prompt_length + bos_bias)
-            for _ in range(len(semantics[0]))
+            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias + i)
+            for i in range(num_codebooks)
         ]
         for segment in semantics:
             for book_idx, book in enumerate(segment):
                 for j in book.values:
                     codes[book_idx].append(int(j) + 2)
 
-        for book in codes:
+        for idx, book in enumerate(codes):
+            book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
             book.append(CODEBOOK_EOS_TOKEN_ID)
 
         tokens = [tokens] + codes
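The codes construction above is a delay pattern: codebook i starts i steps later than codebook 0, and each row is right-padded so that all rows finish together before a shared eos column; offsetting real codes by +2 keeps them clear of the pad/eos ids. A self-contained sketch of just that transform, with assumed toy ids and bos_bias omitted; it mirrors the diff's logic but is illustrative, not the project's API:

    PAD_ID, EOS_ID = 0, 1  # CODEBOOK_PAD_TOKEN_ID / CODEBOOK_EOS_TOKEN_ID

    def delay_pattern(codebooks, prompt_length):
        """Shift codebook i right by i steps; pad the tail so rows align."""
        num_books = len(codebooks)
        rows = []
        for i, book in enumerate(codebooks):
            row = [PAD_ID] * (prompt_length + i)   # prompt cover + per-book delay
            row += [v + 2 for v in book]           # offset codes past pad/eos ids
            row += [PAD_ID] * (num_books - i - 1)  # tail pad so rows line up
            row.append(EOS_ID)                     # shared eos column
            rows.append(row)
        return rows

    for row in delay_pattern([[10, 11, 12], [20, 21, 22]], prompt_length=2):
        print(row)
    # [0, 0, 12, 13, 14, 0, 1]
    # [0, 0, 0, 22, 23, 24, 1]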
@@ -522,14 +520,15 @@ class AutoAugTextDataset(IterableDataset):
 
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (prompt_length + bos_bias)] = -100
+        # labels[1:, : (prompt_length + bos_bias)] = -100
+        labels[:, : (prompt_length + bos_bias)] = -100
 
         tokens = tokens[:, :-1]
         labels = labels[:, 1:]
 
         # Verify the padding is correct, and the last token is eos
         assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
+        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
         assert labels[0, -1] == self.tokenizer.eos_token_id
         assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
 
@@ -677,7 +676,7 @@ if __name__ == "__main__":
         use_speaker=True,
         interactive_prob=1.0,
         phones_prob=1.0,
-        use_negative_samples=True,
+        use_negative_samples=False,
     )
 
     # ds = AutoAugTextDataset(
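One last note on the masking hunk: the old line (kept as a comment) masked the prompt region only in the codebook rows, while the new line masks it in every row, text row included, so only generated positions contribute to the loss; the usual one-step shift for next-token prediction follows. A minimal sketch, with numpy arrays standing in for the dataset's tensors and purely illustrative ids:

    import numpy as np

    IGNORE_INDEX = -100  # the ignore id understood by PyTorch cross-entropy

    tokens = np.array([
        [5, 6, 7, 9, 9, 9, 2],    # row 0: text ids, tokenizer pads, eos (toy ids)
        [0, 0, 0, 12, 13, 0, 1],  # rows 1+: delayed codebooks, pad=0, eos=1
        [0, 0, 0, 0, 22, 23, 1],
    ])
    labels = tokens.copy()
    prompt_length, bos_bias = 3, 0

    # New behaviour: mask the prompt region in every row, text row included.
    labels[:, : prompt_length + bos_bias] = IGNORE_INDEX

    # Standard next-token shift: inputs drop the last column, targets the first.
    tokens = tokens[:, :-1]
    labels = labels[:, 1:]

    assert labels[0, -1] == 2           # text row still ends in its eos (toy id)
    assert (labels[1:, -1] == 1).all()  # codebook rows end in codebook eos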