|
|
@@ -404,6 +404,11 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
sentences = [f"[SPK: {speaker}]"] + sentences
|
|
|
|
|
|
final_text = "[INST] " + " ".join(sentences) + " [/INST]"
|
|
|
+
|
|
|
+ for segment in semantics:
|
|
|
+ for j in segment[0].values:
|
|
|
+ final_text += f" <s:{int(j)}>"
|
|
|
+
|
|
|
encoded = self.tokenizer.encode(
|
|
|
final_text,
|
|
|
add_special_tokens=False,
|
|
|
@@ -411,12 +416,14 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
max_length=10**6,
|
|
|
)
|
|
|
semantic_length = sum([len(i[0].values) for i in semantics])
|
|
|
+ prompt_length = len(encoded) - semantic_length
|
|
|
+
|
|
|
bos_bias = 1 if add_bos else 0
|
|
|
|
|
|
# Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
|
|
tokens = (
|
|
|
encoded
|
|
|
- + [self.tokenizer.pad_token_id] * semantic_length
|
|
|
+ # + [self.tokenizer.pad_token_id] * semantic_length
|
|
|
+ [self.tokenizer.eos_token_id]
|
|
|
)
|
|
|
|
|
|
@@ -425,7 +432,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
# Codebook bos/padding: 0, eos: 1
|
|
|
codes = [
|
|
|
- [CODEBOOK_BOS_TOKEN_ID] * (len(encoded) + bos_bias)
|
|
|
+ [CODEBOOK_BOS_TOKEN_ID] * (prompt_length + bos_bias)
|
|
|
for _ in range(len(semantics[0]))
|
|
|
]
|
|
|
for segment in semantics:
|
|
|
@@ -443,14 +450,14 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
# Since we don't mask out the input tokens, the language modeling still works
|
|
|
- labels[1:, : (len(encoded) + bos_bias)] = -100
|
|
|
+ labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
|
|
|
tokens = tokens[:, :-1]
|
|
|
labels = labels[:, 1:]
|
|
|
|
|
|
# Verify the padding is correct, and the last token is eos
|
|
|
assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
|
|
|
- assert (tokens[1:, : len(encoded) + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
|
|
|
+ assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
|
|
|
assert labels[0, -1] == self.tokenizer.eos_token_id
|
|
|
assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
|
|
|
|