|
|
@@ -13,7 +13,7 @@ from datasets.download.streaming_download_manager import xopen
|
|
|
from huggingface_hub import HfApi
|
|
|
from lightning import LightningDataModule
|
|
|
from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
|
|
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
|
|
|
|
|
|
from fish_speech.conversation import (
|
|
|
CODEBOOK_PAD_TOKEN_ID,
|
|
|
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
|
|
|
return files
|
|
|
|
|
|
|
|
|
-class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
|
|
|
"""
|
|
|
Auto Augment Dataset by Speaker
|
|
|
|
|
|
@@ -295,6 +295,214 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
|
|
|
return data
|
|
|
|
|
|
|
|
|
class AutoTextSemanticInstructionDataset(Dataset):
    """
    Map-style dataset of pre-tokenized text -> semantic (VQ) samples.

    All proto files are read once, at construction time. Every sentence of
    every group becomes one independent sample laid out as a three-message
    conversation:

        system:    "Speak out the provided text."
        user:      the cleaned text (or "<|skip_text|>")
        assistant: "<|voice|>" followed by the VQ codes

    Samples are stored in memory as ``{"tokens": ..., "labels": ...}`` dicts
    and returned as-is by ``__getitem__``.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: FishTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: protobuf files or directories (brace patterns are
                expanded; directories are searched recursively for
                ``*.proto`` / ``*.protos``)
            seed: random seed used for file shuffling, text-variant choice,
                text skipping and the final sample shuffle
            interactive_prob: must be in [0, 1]; stored on the instance but
                not read by the methods of this class
            max_length: stored on the instance but not read by the methods
                of this class (samples are not truncated here)
            tokenizer: tokenizer passed to ``Conversation.encode``
            use_speaker: stored on the instance but not read by the methods
                of this class
            causal: stored on the instance but not read by the methods of
                this class
            num_codebooks: number of codebooks, if None, it is taken from
                the number of codebook rows of each sample
            skip_text_prob: probability of replacing a sample's text with
                "<|skip_text|>" (audio only)
        """
        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        self.data = []
        self._init_data()

    def _init_data(self):
        """
        Read this rank/worker's shard of proto files and fill ``self.data``.

        Raises:
            ValueError: if an expanded path is neither a file nor a directory
        """
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        # Sort first so the seeded shuffle gives the same order everywhere
        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        # Seeded generator so the chosen text variants and the skip-text
        # decisions are reproducible for a given seed (the global `random`
        # module would make the dataset differ between runs).
        rng = Random(self.seed)

        samples = []
        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                # Convert each group as soon as it is read instead of
                # buffering every protobuf message in memory first.
                for group in read_pb_stream(f):
                    count += 1

                    for sentence in group.sentences:
                        text = clean_text(rng.choice(sentence.texts))

                        tokens, labels = self.pack_sentences(
                            sentences=[text],
                            semantics=[sentence.semantics],
                            skip_text=rng.random() < self.skip_text_prob,
                        )

                        samples.append({"tokens": tokens, "labels": labels})

        log.info(f"Read total {count} groups of data")

        Random(self.seed).shuffle(samples)
        self.data = samples

    def __len__(self):
        """Return the number of pre-tokenized samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``{"tokens", "labels"}`` dict stored at ``idx``."""
        return self.data[idx]

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        skip_text: bool = False,
    ):
        """
        Encode text and VQ codes into token / label tensors.

        Args:
            sentences: text pieces, joined with a single space to form the
                user message
            semantics: list of per-sentence codebook sequences; only
                ``semantics[0]`` is used, each of its items exposing
                ``.values``
            skip_text: replace the user text with "<|skip_text|>"

        Returns:
            ``(tokens, labels)``: two int64 tensors with
            ``num_codebooks + 1`` rows. Row 0 holds the encoded text tokens /
            labels, the remaining rows hold the VQ codes at the positions
            selected by the encoder's VQ masks.
        """
        messages = [
            Message(
                role="system",
                parts=[TextPart(text="Speak out the provided text.")],
            )
        ]

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        messages.append(
            Message(
                role="user",
                parts=[TextPart(text=cated_sentences)],
            )
        )

        vq_codes = [x.values for x in semantics[0]]
        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
        vqpart = VQPart(codes=vq_codes_tensor)
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vqpart],
                cal_loss=True,
            )
        )

        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        conversation = Conversation(messages=messages)
        encoded = conversation.encode(
            tokenizer=self.tokenizer,
        )

        tokens_raw = encoded.tokens
        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
        tokens[0] = tokens_raw

        vq_parts = encoded.vq_parts
        vq_parts = [part.to(tokens.device) for part in vq_parts]
        vq_parts = torch.cat(vq_parts, dim=1)
        tokens[1:, encoded.vq_mask_tokens] = vq_parts

        labels_raw = encoded.labels
        # -100 marks positions without a target (presumably the loss
        # ignore index -- confirm against the training code)
        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
        labels[0, :] = labels_raw
        labels[1:, encoded.vq_mask_labels] = vq_parts
        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID

        tokens = tokens.long()
        labels = labels.long()

        # Sanity checks: codebook rows are padding outside the VQ span
        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels
|
|
|
+
|
|
|
+
|
|
|
class InterleaveDataset(IterableDataset):
    """
    Endlessly interleave several datasets by weighted random choice.

    On every step one dataset is picked according to ``probabilities`` and
    its next item is yielded. A dataset that runs out is restarted from the
    beginning, so iteration never terminates on its own.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        """
        Args:
            datasets: datasets to draw from; each must be iterable
            probabilities: sampling probability of each dataset, in the same
                order as ``datasets``
            seed: seed of the generator that picks the dataset
        """
        super().__init__()

        self.datasets = datasets
        self.probabilities = probabilities
        self.seed = seed

    def __iter__(self):
        """
        Yield items forever, choosing the source dataset at random.

        Raises:
            RuntimeError: if the chosen dataset yields no items even right
                after being restarted (i.e. it is empty)
        """
        rng = np.random.default_rng(self.seed)
        dataset_iterators = [iter(dataset) for dataset in self.datasets]

        while True:
            # Random choice one
            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)

            try:
                item = next(dataset_iterators[dataset_idx])
            except StopIteration:
                # Exhausted, create a new iterator
                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
                try:
                    item = next(dataset_iterators[dataset_idx])
                except StopIteration:
                    # A StopIteration escaping a generator is turned into an
                    # opaque RuntimeError by Python; raise a descriptive one.
                    raise RuntimeError(
                        f"Dataset at index {dataset_idx} is empty, "
                        "cannot interleave it"
                    ) from None

            yield item
|
|
|
+
|
|
|
+
|
|
|
@dataclass
|
|
|
class TextDataCollator:
|
|
|
tokenizer: FishTokenizer
|
|
|
@@ -369,41 +577,19 @@ class TextDataCollator:
|
|
|
}
|
|
|
|
|
|
|
|
|
-class InterleaveDataset(IterableDataset):
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- datasets: list[IterableDataset],
|
|
|
- probabilities: list[float],
|
|
|
- seed: int = 42,
|
|
|
- ):
|
|
|
- super().__init__()
|
|
|
-
|
|
|
- self.datasets = datasets
|
|
|
- self.probabilities = probabilities
|
|
|
- self.seed = seed
|
|
|
-
|
|
|
- def __iter__(self):
|
|
|
- rng = np.random.default_rng(self.seed)
|
|
|
- dataset_iterators = [iter(dataset) for dataset in self.datasets]
|
|
|
-
|
|
|
- while True:
|
|
|
- # Random choice one
|
|
|
- dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
|
|
|
- dataset_iterator = dataset_iterators[dataset_idx]
|
|
|
-
|
|
|
- try:
|
|
|
- yield next(dataset_iterator)
|
|
|
- except StopIteration:
|
|
|
- # Exhausted, create a new iterator
|
|
|
- dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
|
|
|
- yield next(dataset_iterators[dataset_idx])
|
|
|
-
|
|
|
-
|
|
|
class SemanticDataModule(LightningDataModule):
|
|
|
def __init__(
|
|
|
self,
|
|
|
- train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
|
|
- val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
|
|
+ train_dataset: Union[
|
|
|
+ AutoTextSemanticInstructionDataset,
|
|
|
+ AutoTextSemanticInstructionIterableDataset,
|
|
|
+ InterleaveDataset,
|
|
|
+ ],
|
|
|
+ val_dataset: Union[
|
|
|
+ AutoTextSemanticInstructionDataset,
|
|
|
+ AutoTextSemanticInstructionIterableDataset,
|
|
|
+ InterleaveDataset,
|
|
|
+ ],
|
|
|
batch_size: int = 32,
|
|
|
tokenizer: FishTokenizer = None,
|
|
|
max_length: int = 1024,
|
|
|
@@ -448,7 +634,6 @@ if __name__ == "__main__":
|
|
|
skip_text_prob=0.5,
|
|
|
)
|
|
|
|
|
|
- for i in ds:
|
|
|
+ for i in range(100):
|
|
|
# Please uncomment line 235 to visualize the tokenized message
|
|
|
- print(i)
|
|
|
- break
|
|
|
+ print(ds[i])
|