@@ -26,6 +26,11 @@ from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
+CODEBOOK_BOS_TOKEN_ID = 0
+CODEBOOK_EOS_TOKEN_ID = 1
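+# Inside each codebook row, 0 doubles as bos/padding and 1 as eos; raw
+# codebook values are shifted by +2 in pack_sentences to stay clear of them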
+
 
 def split_by_rank_worker(files):
     # We need to know the total number of devices
@@ -171,6 +176,15 @@ class AutoAugTextDataset(IterableDataset):
     1. Random concatenate multiple sentences from the same speaker to form a longer sentence
     2. Automatically normalize the text
     3. Mix text and phones
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
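+
+    In both modes the semantic codebook values are carried in rows parallel
+    to the text tokens, filling the padded positions after each [/INST].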
     """
 
     def __init__(
@@ -179,6 +193,7 @@ class AutoAugTextDataset(IterableDataset):
         seed: int = 42,
         phones_prob: float = 0.3,
         repetition_prob: float = 0.0,
+        interactive_prob: float = 0.5,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
         use_speaker: bool = True,
@@ -189,6 +204,7 @@ class AutoAugTextDataset(IterableDataset):
             seed: random seed
             phones_prob: probability to use phones
             repetition_prob: probability to repeat the same sentence
+            interactive_prob: probability to use interactive mode
             max_length: max length of the text
             tokenizer: tokenizer
         """
@@ -200,6 +216,7 @@ class AutoAugTextDataset(IterableDataset):
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.repetition_prob = repetition_prob
+        self.interactive_prob = interactive_prob
         self.use_speaker = use_speaker
 
         # Read all lines, and group by speaker
@@ -258,6 +275,12 @@ class AutoAugTextDataset(IterableDataset):
             return None
 
         samples = list(response.samples)
+        idx = 0
+        use_interactive = random.random() < self.interactive_prob
+
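+        # In interactive mode each sentence is packed on its own and the
+        # chunks are concatenated along the time axis after the loop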
+        all_tokens, all_labels = [], []
         while remaining_tokens > 0 and len(samples) > 0:
             if random.random() < self.repetition_prob:
                 # Repeat the same sentence
@@ -269,70 +292,111 @@
                 sentence.text, sentence.phones, mode=mode
             )
             remaining_tokens -= length + len(sentence.semantics[0].values)
-            final_text.append(text)
-            final_semantic.append(sentence.semantics)
 
-        if self.use_speaker:
-            final_text = [f"[SPK: {response.name}]"] + final_text
+            if not use_interactive:
+                final_text.append(text)
+                final_semantic.append(sentence.semantics)
+            else:
+                # For interactive mode, we only apply speaker for the first sentence
+                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    speaker=response.name if (self.use_speaker and idx == 0) else None,
+                    add_bos=idx == 0,
+                )
+
+                all_tokens.append(tokens)
+                all_labels.append(labels)
+
+            idx += 1
+
+        if not use_interactive:
+            tokens, labels = self.pack_sentences(
+                final_text,
+                semantics=final_semantic,
+                speaker=response.name if self.use_speaker else None,
+                add_bos=True,
+            )
+        else:
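+            # Concatenate the per-sentence chunks along the time axis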
+            tokens = torch.cat(all_tokens, dim=1)
+            labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        # Verify only one <s> token
+        assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1
+
+        return {"tokens": tokens, "labels": labels}
 
-        final_text = "[INST] " + " ".join(final_text) + " [/INST]"
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        speaker: Optional[str] = None,
+        add_bos: bool = True,
+    ):
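+        # Pack the sentences and their semantic codebook rows into aligned
+        # (tokens, labels) tensors of shape [1 + num_codebooks, seq_len]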
+        if speaker is not None:
+            sentences = [f"[SPK: {speaker}]"] + sentences
+
+        final_text = "[INST] " + " ".join(sentences) + " [/INST]"
         encoded = self.tokenizer.encode(
             final_text,
             add_special_tokens=False,
             truncation=False,
             max_length=10**6,
         )
-        semantic_length = sum([len(i[0].values) for i in final_semantic])
+        semantic_length = sum([len(i[0].values) for i in semantics])
+        bos_bias = 1 if add_bos else 0
+
+        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        tokens = (
+            encoded
+            + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.eos_token_id]
+        )
 
-        # Single codebook
-        if len(final_semantic[0]) == 1:
-            semantic_tokens = [f"<s:{j}>" for i in final_semantic for j in i[0].values]
-            tokenized = self.tokenizer.encode(
-                f" ".join(semantic_tokens),
-                add_special_tokens=False,
-                truncation=False,
-                max_length=10**6,
-            )
+        if add_bos:
+            tokens = [self.tokenizer.bos_token_id] + tokens
 
-            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-            tokens = (
-                [self.tokenizer.bos_token_id]
-                + encoded
-                + tokenized
-                + [self.tokenizer.eos_token_id]
-            )
-            tokens = torch.tensor([tokens], dtype=torch.long)
-            labels = tokens.clone()
-            labels[0, : len(encoded) + 1] = -100  # Mask out the <s> and query tokens
-        else:
-            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-            tokens = (
-                [self.tokenizer.bos_token_id]
-                + encoded
-                + [self.tokenizer.pad_token_id] * semantic_length
-                + [self.tokenizer.eos_token_id]
-            )
-            codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
-            for segment in final_semantic:
-                for book_idx, book in enumerate(segment):
-                    for j in book.values:
-                        codes[book_idx].append(int(j) + 2)
+        # Codebook bos/padding: 0, eos: 1
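+        # (raw codebook values get +2 below so they never collide with these ids)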
+        codes = [
+            [CODEBOOK_BOS_TOKEN_ID] * (len(encoded) + bos_bias)
+            for _ in range(len(semantics[0]))
+        ]
+        for segment in semantics:
+            for book_idx, book in enumerate(segment):
+                for j in book.values:
+                    codes[book_idx].append(int(j) + 2)
 
-        for book in codes:
-            book.append(1)
+        for book in codes:
+            book.append(CODEBOOK_EOS_TOKEN_ID)
 
-        tokens = [tokens] + codes
+        tokens = [tokens] + codes
 
-        tokens = torch.tensor(tokens, dtype=torch.long)
-        labels = tokens.clone()
-        labels[
-            1:, : len(encoded) + 1
-        ] = -100  # Mask out the <s> tokens for semantic
+        tokens = torch.tensor(tokens, dtype=torch.long)
+        labels = tokens.clone()
 
-        return {
-            "tokens": tokens[:, :-1],
-            "labels": labels[:, 1:],
-        }
+        # Mask out the <s> tokens for semantic, predict semantic tokens only
+        # Since we don't mask out the input tokens, the language modeling still works
+        labels[1:, : (len(encoded) + bos_bias)] = -100
+
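+        # Shift by one position for next-token prediction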
+        tokens = tokens[:, :-1]
+        labels = labels[:, 1:]
+
+        # Verify the padding is correct, and the last token is eos
+        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
+        assert (tokens[1:, : len(encoded) + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
+        assert labels[0, -1] == self.tokenizer.eos_token_id
+        assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
+
+        return tokens, labels
 
 
 @dataclass
@@ -346,18 +410,21 @@ class TextDataCollator:
             _tokens = example["tokens"][:, : self.max_length]
             _labels = example["labels"][:, : self.max_length]
             _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
-            _attention_mask[: _tokens.size(1)] = False
+            tokens_length = _tokens.size(1)
+            _attention_mask[:tokens_length] = False
 
-            assert _tokens.size(1) == _labels.size(
+            assert tokens_length == _labels.size(
                 1
-            ), f"{_tokens.size(1)} != {_labels.size(1)}"
+            ), f"{tokens_length} != {_labels.size(1)}"
 
-            if _tokens.size(1) < self.max_length:
+            if tokens_length < self.max_length:
                 _tokens = F.pad(
                     _tokens,
-                    (0, self.max_length - _tokens.size(1)),
+                    (0, self.max_length - tokens_length),
                     value=self.tokenizer.eos_token_id,
                 )
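+            # Pad the codebook rows with the codebook eos id, not the text eos id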
+            _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
             _labels = F.pad(
                 _labels, (0, self.max_length - _labels.size(1)), value=-100
             )
@@ -444,20 +511,14 @@ class TextDataModule(LightningDataModule):
 
 
 if __name__ == "__main__":
-    import json
-
     from tqdm import tqdm
 
     ds = AutoAugTextDataset(
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
         use_speaker=True,
+        interactive_prob=1.0,
     )
 
-    # ds = StreamTextDataset(
-    #     prefix="en/",
-    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
-    # )
-
     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,