|
|
@@ -76,7 +76,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
tokenizer: AutoTokenizer = None,
|
|
|
use_speaker: bool | float = True,
|
|
|
causal: bool = True,
|
|
|
- use_negative_samples: bool = False,
|
|
|
num_codebooks: Optional[int] = None,
|
|
|
skip_text_prob: float = 0.0,
|
|
|
):
|
|
|
@@ -89,7 +88,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
tokenizer: tokenizer
|
|
|
use_speaker: include speaker information in the prompt
|
|
|
causal: use causal sampling when using local data, disable will lead to random sampling
|
|
|
- use_negative_samples: generate negative samples
|
|
|
num_codebooks: number of codebooks, if None, it will be automatically detected
|
|
|
skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
|
|
|
"""
|
|
|
@@ -105,7 +103,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
self.use_speaker = use_speaker
|
|
|
self.proto_files = proto_files
|
|
|
self.causal = causal
|
|
|
- self.use_negative_samples = use_negative_samples
|
|
|
self.num_codebooks = num_codebooks
|
|
|
self.skip_text_prob = skip_text_prob
|
|
|
|
|
|
@@ -242,7 +239,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
sentences=[text],
|
|
|
semantics=[sentence.semantics],
|
|
|
speaker=response.name if use_speaker else None,
|
|
|
- add_bos=idx == 0,
|
|
|
skip_text=random.random() < self.skip_text_prob,
|
|
|
)
|
|
|
|
|
|
@@ -256,7 +252,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
final_text,
|
|
|
semantics=final_semantic,
|
|
|
speaker=response.name if use_speaker else None,
|
|
|
- add_bos=True,
|
|
|
)
|
|
|
all_tokens.append(tokens)
|
|
|
all_labels.append(labels)
|
|
|
@@ -267,84 +262,15 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
# Verify that the length is correct
|
|
|
assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
|
|
|
|
- # Verify bos token
|
|
|
- assert tokens[0, 0] == self.tokenizer.bos_token_id
|
|
|
-
|
|
|
data = {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
- if self.use_negative_samples:
|
|
|
- negative_samples = self.generate_negative_samples(all_tokens, all_labels)
|
|
|
- data.update(negative_samples)
|
|
|
-
|
|
|
return data
|
|
|
|
|
|
- def generate_negative_samples(self, all_tokens, all_labels):
|
|
|
- new_tokens, new_labels = [], []
|
|
|
-
|
|
|
- for tokens, labels in zip(all_tokens, all_labels):
|
|
|
- # If all codebooks are not -100, we find where it starts
|
|
|
- start = torch.where(labels[1:].sum(0) != -100 * (labels.size(0) - 1))[0][0]
|
|
|
- assert (labels[1:, start:] != -100).all() # This shouldn't happen
|
|
|
-
|
|
|
- mode = random.choice(["repeat", "lost", "noise"])
|
|
|
- begin = random.randint(start, labels.size(1) - 1)
|
|
|
- end = random.randint(begin, labels.size(1) - 1)
|
|
|
-
|
|
|
- if mode == "repeat":
|
|
|
- tokens = torch.cat(
|
|
|
- [
|
|
|
- tokens[:, :begin],
|
|
|
- tokens[:, begin:end],
|
|
|
- tokens[:, begin:end],
|
|
|
- tokens[:, end:],
|
|
|
- ],
|
|
|
- dim=1,
|
|
|
- )
|
|
|
- labels = torch.cat(
|
|
|
- [
|
|
|
- labels[:, :begin],
|
|
|
- labels[:, begin:end],
|
|
|
- labels[:, begin:end],
|
|
|
- labels[:, end:],
|
|
|
- ],
|
|
|
- dim=1,
|
|
|
- )
|
|
|
- elif mode == "lost":
|
|
|
- tokens = torch.cat([tokens[:, :begin], tokens[:, end:]], dim=1)
|
|
|
- labels = torch.cat([labels[:, :begin], labels[:, end:]], dim=1)
|
|
|
- elif mode == "noise":
|
|
|
- middle_tokens, middle_labels = (
|
|
|
- tokens[:, begin:end],
|
|
|
- labels[:, begin:end],
|
|
|
- )
|
|
|
- random_order0 = torch.randperm(middle_tokens.size(1))
|
|
|
- random_order1 = torch.randperm(middle_tokens.size(1))
|
|
|
- middle_tokens = middle_tokens[:, random_order0]
|
|
|
- middle_labels = middle_labels[:, random_order1]
|
|
|
- tokens = torch.cat(
|
|
|
- [tokens[:, :begin], middle_tokens, tokens[:, end:]], dim=1
|
|
|
- )
|
|
|
- labels = torch.cat(
|
|
|
- [labels[:, :begin], middle_labels, labels[:, end:]], dim=1
|
|
|
- )
|
|
|
-
|
|
|
- new_tokens.append(tokens)
|
|
|
- new_labels.append(labels)
|
|
|
-
|
|
|
- tokens = torch.cat(new_tokens, dim=1)
|
|
|
- labels = torch.cat(new_labels, dim=1)
|
|
|
-
|
|
|
- # Verify that the length is correct
|
|
|
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
|
-
|
|
|
- return {"negative_tokens": tokens, "negative_labels": labels}
|
|
|
-
|
|
|
def pack_sentences(
|
|
|
self,
|
|
|
sentences: list[str],
|
|
|
semantics: list,
|
|
|
speaker: Optional[str] = None,
|
|
|
- add_bos: bool = True,
|
|
|
skip_text: bool = False,
|
|
|
):
|
|
|
if speaker is None:
|
|
|
@@ -369,8 +295,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
|
|
)
|
|
|
|
|
|
- bos_bias = 1 if add_bos else 0
|
|
|
-
|
|
|
# Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
|
|
tokens = (
|
|
|
encoded
|
|
|
@@ -378,14 +302,8 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
+ self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
|
|
|
)
|
|
|
|
|
|
- if add_bos:
|
|
|
- tokens = [self.tokenizer.bos_token_id] + tokens
|
|
|
-
|
|
|
# Codebook bos/padding: 0, eos: 1
|
|
|
- codes = [
|
|
|
- [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias)
|
|
|
- for _ in range(num_codebooks)
|
|
|
- ]
|
|
|
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
|
|
|
for segment in semantics:
|
|
|
for book_idx, book in zip(range(num_codebooks), segment):
|
|
|
for j in book.values:
|
|
|
@@ -406,14 +324,13 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
|
|
|
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
# Since we don't mask out the input tokens, the language modeling still works
|
|
|
- labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
+ labels[1:, :prompt_length] = -100
|
|
|
|
|
|
tokens = tokens[:, :-1]
|
|
|
labels = labels[:, 1:]
|
|
|
|
|
|
# Verify the padding is correct, and the last token is eos
|
|
|
- assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
|
|
|
- assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
|
|
|
return tokens, labels
|
|
|
@@ -564,12 +481,11 @@ class SemanticDataModule(LightningDataModule):
|
|
|
if __name__ == "__main__":
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
- ds = AutoAugTextDataset(
|
|
|
+ ds = AutoTextSemanticInstructionDataset(
|
|
|
["data/protos"],
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
|
|
|
use_speaker=False,
|
|
|
interactive_prob=1.0,
|
|
|
- use_negative_samples=False,
|
|
|
skip_text_prob=0.5,
|
|
|
)
|
|
|
|