Procházet zdrojové kódy

[feature]add dataset classs (#775)

* [feature]add dataset classs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Whale and Dolphin před 1 rokem
rodič
revize
7902e408c8

+ 2 - 2
fish_speech/configs/text2semantic_finetune.yaml

@@ -26,7 +26,7 @@ tokenizer:
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}
@@ -36,7 +36,7 @@ train_dataset:
   interactive_prob: 0.7
 
 val_dataset:
-  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}

+ 222 - 37
fish_speech/datasets/semantic.py

@@ -13,7 +13,7 @@ from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
 
 from fish_speech.conversation import (
     CODEBOOK_PAD_TOKEN_ID,
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
     return files
 
 
-class AutoTextSemanticInstructionDataset(IterableDataset):
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
     """
     Auto Augment Dataset by Speaker
 
@@ -295,6 +295,214 @@ class AutoTextSemanticInstructionDataset(IterableDataset):
         return data
 
 
+class AutoTextSemanticInstructionDataset(Dataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: FishTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: proto buf files if using local data
+            seed: random seed
+            interactive_prob: probability to use interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data, disable will lead to random sampling
+            num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+        """
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.data = []
+        self._init_data()
+
+    def _init_data(self):
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        for group in groups:
+            if len(group.sentences) == 0:
+                continue
+
+            samples = list(group.sentences)
+            for sentence in samples:
+                text = clean_text(random.choice(sentence.texts))
+
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    skip_text=random.random() < self.skip_text_prob,
+                )
+
+                self.data.append({"tokens": tokens, "labels": labels})
+
+        random.Random(self.seed).shuffle(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        skip_text: bool = False,
+    ):
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+            )
+        ]
+
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+            )
+        )
+
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
+        )
+
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        conversation = Conversation(messages=messages)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
+        )
+
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+        tokens = tokens.long()
+        labels = labels.long()
+
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 @dataclass
 class TextDataCollator:
     tokenizer: FishTokenizer
@@ -369,41 +577,19 @@ class TextDataCollator:
         }
 
 
-class InterleaveDataset(IterableDataset):
-    def __init__(
-        self,
-        datasets: list[IterableDataset],
-        probabilities: list[float],
-        seed: int = 42,
-    ):
-        super().__init__()
-
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.seed = seed
-
-    def __iter__(self):
-        rng = np.random.default_rng(self.seed)
-        dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
-        while True:
-            # Random choice one
-            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
-            dataset_iterator = dataset_iterators[dataset_idx]
-
-            try:
-                yield next(dataset_iterator)
-            except StopIteration:
-                # Exhausted, create a new iterator
-                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
-                yield next(dataset_iterators[dataset_idx])
-
-
 class SemanticDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
-        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
         batch_size: int = 32,
         tokenizer: FishTokenizer = None,
         max_length: int = 1024,
@@ -448,7 +634,6 @@ if __name__ == "__main__":
         skip_text_prob=0.5,
     )
 
-    for i in ds:
+    for i in range(100):
         # Please uncomment line 235 to visualize the tokenized message
-        print(i)
-        break
+        print(ds[i])