|
|
@@ -14,12 +14,18 @@ from huggingface_hub import HfApi
|
|
|
from lightning import LightningDataModule
|
|
|
from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
|
-from transformers import AutoTokenizer
|
|
|
|
|
|
-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
|
|
+from fish_speech.conversation import (
|
|
|
+ CODEBOOK_PAD_TOKEN_ID,
|
|
|
+ Conversation,
|
|
|
+ Message,
|
|
|
+ TextPart,
|
|
|
+ VQPart,
|
|
|
+)
|
|
|
from fish_speech.datasets.protos.text_data_pb2 import SampledData
|
|
|
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
|
|
from fish_speech.text.clean import clean_text
|
|
|
+from fish_speech.tokenizer import FishTokenizer
|
|
|
from fish_speech.utils import RankedLogger
|
|
|
from fish_speech.utils.braceexpand import braceexpand
|
|
|
|
|
|
@@ -73,7 +79,7 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
seed: int = 42,
|
|
|
interactive_prob: float = 0.5,
|
|
|
max_length: int = 1024,
|
|
|
- tokenizer: AutoTokenizer = None,
|
|
|
+ tokenizer: FishTokenizer = None,
|
|
|
use_speaker: bool | float = True,
|
|
|
causal: bool = True,
|
|
|
num_codebooks: Optional[int] = None,
|
|
|
@@ -106,9 +112,12 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
self.num_codebooks = num_codebooks
|
|
|
self.skip_text_prob = skip_text_prob
|
|
|
|
|
|
- self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
|
|
self.groups = None
|
|
|
|
|
|
+ def __iter__(self):
|
|
|
+ while True:
|
|
|
+ yield self.augment()
|
|
|
+
|
|
|
def init_mock_data_server(self):
|
|
|
if self.groups is not None:
|
|
|
return
|
|
|
@@ -148,20 +157,6 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
Random(self.seed).shuffle(self.groups)
|
|
|
self.group_weights = [len(i.sentences) for i in self.groups]
|
|
|
|
|
|
- def __iter__(self):
|
|
|
- while True:
|
|
|
- yield self.augment()
|
|
|
-
|
|
|
- def tokenize_sentence(self, sentence: str):
|
|
|
- sentence = clean_text(sentence)
|
|
|
- tokens = self.tokenizer.encode(
|
|
|
- f"{sentence}",
|
|
|
- max_length=10**6,
|
|
|
- add_special_tokens=False,
|
|
|
- truncation=False,
|
|
|
- )
|
|
|
- return sentence, len(tokens)
|
|
|
-
|
|
|
def sample_data(self):
|
|
|
if self.groups is None:
|
|
|
self.init_mock_data_server()
|
|
|
@@ -190,155 +185,119 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
samples=samples,
|
|
|
)
|
|
|
|
|
|
- def augment(self):
|
|
|
- final_text, final_semantic = [], []
|
|
|
- response = self.sample_data()
|
|
|
- if len(response.samples) == 0:
|
|
|
- # Invalid group
|
|
|
- return None
|
|
|
-
|
|
|
- samples = list(response.samples)
|
|
|
- idx = 0
|
|
|
- use_interactive = random.random() < self.interactive_prob
|
|
|
-
|
|
|
- if use_interactive is False:
|
|
|
- # Random sample based on speaker using a truncated normal distribution
|
|
|
- a = torch.tensor([0], dtype=torch.float32)
|
|
|
- torch.nn.init.trunc_normal_(
|
|
|
- a,
|
|
|
- mean=self.max_length // 2,
|
|
|
- std=self.max_length // 4,
|
|
|
- a=10,
|
|
|
- b=self.max_length,
|
|
|
- )
|
|
|
- remaining_tokens = a.long().item() - 4
|
|
|
- else:
|
|
|
- remaining_tokens = self.max_length
|
|
|
-
|
|
|
- # Use speaker
|
|
|
- if isinstance(self.use_speaker, float):
|
|
|
- use_speaker = random.random() < self.use_speaker
|
|
|
- else:
|
|
|
- use_speaker = self.use_speaker
|
|
|
-
|
|
|
- all_tokens, all_labels = [], []
|
|
|
- while remaining_tokens > 0 and len(samples) > 0:
|
|
|
- sentence = samples.pop(0)
|
|
|
-
|
|
|
- text = random.choice(sentence.texts)
|
|
|
- text, length = self.tokenize_sentence(text)
|
|
|
- remaining_tokens -= length + len(sentence.semantics[0].values)
|
|
|
-
|
|
|
- if use_interactive is False:
|
|
|
- final_text.append(text)
|
|
|
- final_semantic.append(sentence.semantics)
|
|
|
- else:
|
|
|
- # For interactive mode, we only apply speaker for the first sentence
|
|
|
- # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
|
|
|
- tokens, labels = self.pack_sentences(
|
|
|
- sentences=[text],
|
|
|
- semantics=[sentence.semantics],
|
|
|
- speaker=response.name if use_speaker else None,
|
|
|
- skip_text=random.random() < self.skip_text_prob,
|
|
|
- )
|
|
|
-
|
|
|
- all_tokens.append(tokens)
|
|
|
- all_labels.append(labels)
|
|
|
-
|
|
|
- idx += 1
|
|
|
-
|
|
|
- if use_interactive is False:
|
|
|
- tokens, labels = self.pack_sentences(
|
|
|
- final_text,
|
|
|
- semantics=final_semantic,
|
|
|
- speaker=response.name if use_speaker else None,
|
|
|
- )
|
|
|
- all_tokens.append(tokens)
|
|
|
- all_labels.append(labels)
|
|
|
-
|
|
|
- tokens = torch.cat(all_tokens, dim=1)
|
|
|
- labels = torch.cat(all_labels, dim=1)
|
|
|
-
|
|
|
- # Verify that the length is correct
|
|
|
- assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
|
-
|
|
|
- data = {"tokens": tokens, "labels": labels}
|
|
|
-
|
|
|
- return data
|
|
|
-
|
|
|
def pack_sentences(
|
|
|
self,
|
|
|
sentences: list[str],
|
|
|
semantics: list,
|
|
|
- speaker: Optional[str] = None,
|
|
|
+ # speaker: Optional[str] = None,
|
|
|
skip_text: bool = False,
|
|
|
):
|
|
|
- if speaker is None:
|
|
|
- speaker = "assistant"
|
|
|
+ # if speaker is None:
|
|
|
+ # speaker = "assistant"
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ Message(
|
|
|
+ role="system",
|
|
|
+ parts=[TextPart(text="Speak out the provided text.")],
|
|
|
+ # add_im_end=False,
|
|
|
+ # cal_loss=True,
|
|
|
+ )
|
|
|
+ ]
|
|
|
|
|
|
cated_sentences = " ".join(sentences)
|
|
|
if skip_text:
|
|
|
cated_sentences = "<|skip_text|>"
|
|
|
|
|
|
- final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
|
|
|
- final_text = final_text + f"<|im_start|>{speaker}\n"
|
|
|
+ messages.append(
|
|
|
+ Message(
|
|
|
+ role="user",
|
|
|
+ parts=[TextPart(text=cated_sentences)],
|
|
|
+ # cal_loss=True,
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
- encoded = self.tokenizer.encode(
|
|
|
- final_text,
|
|
|
- add_special_tokens=False,
|
|
|
- truncation=False,
|
|
|
- max_length=10**6,
|
|
|
+ vq_codes = [x.values for x in semantics[0]]
|
|
|
+ vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
|
|
|
+ vqpart = VQPart(codes=vq_codes_tensor)
|
|
|
+ messages.append(
|
|
|
+ Message(
|
|
|
+ role="assistant",
|
|
|
+ parts=[TextPart(text="<|voice|>"), vqpart],
|
|
|
+ cal_loss=True,
|
|
|
+ )
|
|
|
)
|
|
|
- semantic_length = sum([len(i[0].values) for i in semantics])
|
|
|
- prompt_length = len(encoded)
|
|
|
+
|
|
|
num_codebooks = (
|
|
|
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
|
|
)
|
|
|
|
|
|
- # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
|
|
- tokens = (
|
|
|
- encoded
|
|
|
- + [self.semantic_token_id] * semantic_length
|
|
|
- + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
|
|
|
+ conversation = Conversation(messages=messages)
|
|
|
+ # conversation.visualize(tokenizer=self.tokenizer)
|
|
|
+ encoded = conversation.encode(
|
|
|
+ tokenizer=self.tokenizer,
|
|
|
)
|
|
|
|
|
|
- # Codebook bos/padding: 0, eos: 1
|
|
|
- codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
|
|
|
- for segment in semantics:
|
|
|
- for book_idx, book in zip(range(num_codebooks), segment):
|
|
|
- for j in book.values:
|
|
|
- codes[book_idx].append(int(j) + 1)
|
|
|
+ tokens_raw = encoded.tokens
|
|
|
+ tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
|
|
|
+ tokens[0] = tokens_raw
|
|
|
|
|
|
- for book in codes:
|
|
|
- book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
|
|
|
+ vq_parts = encoded.vq_parts
|
|
|
+ vq_parts = [part.to(tokens.device) for part in vq_parts]
|
|
|
+ vq_parts = torch.cat(vq_parts, dim=1)
|
|
|
+ tokens[1:, encoded.vq_mask_tokens] = vq_parts
|
|
|
|
|
|
- tokens = [tokens] + codes
|
|
|
+ labels_raw = encoded.labels
|
|
|
+ labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
|
|
|
+ labels[0, :] = labels_raw
|
|
|
+ labels[1:, encoded.vq_mask_labels] = vq_parts
|
|
|
+ labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
|
|
|
|
|
|
- tokens = torch.tensor(tokens, dtype=torch.long)
|
|
|
- labels = tokens.clone()
|
|
|
-
|
|
|
- if skip_text:
|
|
|
- # If text is not provided, the sentence is used for condition only, all labels are -100
|
|
|
- torch.fill_(labels, -100)
|
|
|
- return tokens, labels
|
|
|
-
|
|
|
- # Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
- # Since we don't mask out the input tokens, the language modeling still works
|
|
|
- labels[1:, :prompt_length] = -100
|
|
|
-
|
|
|
- tokens = tokens[:, :-1]
|
|
|
- labels = labels[:, 1:]
|
|
|
+ tokens = tokens.long()
|
|
|
+ labels = labels.long()
|
|
|
|
|
|
# Verify the padding is correct, and the last token is eos
|
|
|
- assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
+ assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
|
|
|
|
|
|
return tokens, labels
|
|
|
|
|
|
+ def augment(self):
|
|
|
+ response = self.sample_data()
|
|
|
+ if len(response.samples) == 0:
|
|
|
+ # Invalid group
|
|
|
+ return None
|
|
|
+
|
|
|
+ samples = list(response.samples)
|
|
|
+ all_tokens, all_labels = [], []
|
|
|
+
|
|
|
+ while len(samples) > 0:
|
|
|
+ sentence = samples.pop(0)
|
|
|
+ text = clean_text(random.choice(sentence.texts))
|
|
|
+
|
|
|
+ tokens, labels = self.pack_sentences(
|
|
|
+ sentences=[text],
|
|
|
+ semantics=[sentence.semantics],
|
|
|
+ # speaker=response.name if use_speaker else None,
|
|
|
+ skip_text=random.random() < self.skip_text_prob,
|
|
|
+ )
|
|
|
+
|
|
|
+ all_tokens.append(tokens)
|
|
|
+ all_labels.append(labels)
|
|
|
+
|
|
|
+ tokens = torch.cat(all_tokens, dim=1)
|
|
|
+ labels = torch.cat(all_labels, dim=1)
|
|
|
+
|
|
|
+ # Verify that the length is correct
|
|
|
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
|
|
+
|
|
|
+ data = {"tokens": tokens, "labels": labels}
|
|
|
+
|
|
|
+ return data
|
|
|
+
|
|
|
|
|
|
@dataclass
|
|
|
class TextDataCollator:
|
|
|
- tokenizer: AutoTokenizer
|
|
|
+ tokenizer: FishTokenizer
|
|
|
max_length: int = 1024
|
|
|
|
|
|
def __call__(self, examples):
|
|
|
@@ -388,7 +347,7 @@ class TextDataCollator:
|
|
|
_tokens = F.pad(
|
|
|
_tokens,
|
|
|
(0, max_tokens_length - tokens_length),
|
|
|
- value=self.tokenizer.eos_token_id,
|
|
|
+ value=self.tokenizer.get_token_id("<|end_of_text|>"),
|
|
|
)
|
|
|
_tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
|
|
|
_labels = F.pad(
|
|
|
@@ -446,7 +405,7 @@ class SemanticDataModule(LightningDataModule):
|
|
|
train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
|
|
val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
|
|
batch_size: int = 32,
|
|
|
- tokenizer: AutoTokenizer = None,
|
|
|
+ tokenizer: FishTokenizer = None,
|
|
|
max_length: int = 1024,
|
|
|
num_workers: int = 4,
|
|
|
):
|
|
|
@@ -483,14 +442,13 @@ if __name__ == "__main__":
|
|
|
|
|
|
ds = AutoTextSemanticInstructionDataset(
|
|
|
["data/protos"],
|
|
|
- tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
|
|
|
+ tokenizer=FishTokenizer("checkpoints/fish-speech-1.5/tokenizer.tiktoken"),
|
|
|
use_speaker=False,
|
|
|
interactive_prob=1.0,
|
|
|
skip_text_prob=0.5,
|
|
|
)
|
|
|
|
|
|
for i in ds:
|
|
|
- print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
|
|
|
- # i["labels"][0][i["labels"][0] == -100] = 0
|
|
|
- # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
|
|
|
+ # Please uncomment line 235 to visualize the tokenized message
|
|
|
+ print(i)
|
|
|
break
|