Ver Fonte

Optimize text dataset

Autor: Lengyue · há 2 anos · commit d5bb7f526a
2 ficheiros alterados com 125 adições e 68 exclusões
  1. fish_speech/datasets/text.py — 111 adições, 62 exclusões
  2. tools/llama/build_dataset.py — 14 adições, 6 exclusões

+ 111 - 62
fish_speech/datasets/text.py

@@ -26,6 +26,9 @@ from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
+CODEBOOK_BOS_TOKEN_ID = 0
+CODEBOOK_EOS_TOKEN_ID = 1
+
 
 def split_by_rank_worker(files):
     # We need to know the total number of devices
@@ -171,6 +174,12 @@ class AutoAugTextDataset(IterableDataset):
     1. Random concatenate multiple sentences from the same speaker to form a longer sentence
     2. Automatically normalize the text
     3. Mix text and phones
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
     """
 
     def __init__(
@@ -179,6 +188,7 @@ class AutoAugTextDataset(IterableDataset):
         seed: int = 42,
         phones_prob: float = 0.3,
         repetition_prob: float = 0.0,
+        interactive_prob: float = 0.5,
         max_length: int = 1024,
         tokenizer: AutoTokenizer = None,
         use_speaker: bool = True,
@@ -189,6 +199,7 @@ class AutoAugTextDataset(IterableDataset):
             seed: random seed
             phones_prob: probability to use phones
             repetition_prob: probability to repeat the same sentence
+            interactive_prob: probability to use interactive mode
             max_length: max length of the text
             tokenizer: tokenizer
         """
@@ -200,6 +211,7 @@ class AutoAugTextDataset(IterableDataset):
         self.max_length = max_length
         self.tokenizer = tokenizer
         self.repetition_prob = repetition_prob
+        self.interactive_prob = interactive_prob
         self.use_speaker = use_speaker
 
         # Read all lines, and group by speaker
@@ -258,6 +270,10 @@ class AutoAugTextDataset(IterableDataset):
             return None
 
         samples = list(response.samples)
+        idx = 0
+        use_interactive = random.random() < self.interactive_prob
+
+        all_tokens, all_labels = [], []
         while remaining_tokens > 0 and len(samples) > 0:
             if random.random() < self.repetition_prob:
                 # Repeat the same sentence
@@ -269,70 +285,107 @@ class AutoAugTextDataset(IterableDataset):
                 sentence.text, sentence.phones, mode=mode
             )
             remaining_tokens -= length + len(sentence.semantics[0].values)
-            final_text.append(text)
-            final_semantic.append(sentence.semantics)
 
-        if self.use_speaker:
-            final_text = [f"[SPK: {response.name}]"] + final_text
+            if use_interactive is False:
+                final_text.append(text)
+                final_semantic.append(sentence.semantics)
+            else:
+                # For interactive mode, we only apply speaker for the first sentence
+                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    speaker=response.name if (self.use_speaker and idx == 0) else None,
+                    add_bos=idx == 0,
+                )
+
+                all_tokens.append(tokens)
+                all_labels.append(labels)
+
+            idx += 1
+
+        if use_interactive is False:
+            tokens, labels = self.pack_sentences(
+                final_text,
+                semantics=final_semantic,
+                speaker=None if self.use_speaker else sentence.speaker,
+                add_bos=True,
+            )
+        else:
+            print(all_tokens[0].shape)
+            tokens = torch.cat(all_tokens, dim=1)
+            labels = torch.cat(all_labels, dim=1)
+
+        # Verify that the length is correct
+        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
+
+        # Verify only one <s> token
+        assert (tokens[:, 0] == self.tokenizer.bos_token_id).sum() == 1
+
+        return {"tokens": tokens, "labels": labels}
 
-        final_text = "[INST] " + " ".join(final_text) + " [/INST]"
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics=list,
+        speaker: Optional[str] = None,
+        add_bos: bool = True,
+    ):
+        if speaker is not None:
+            sentences = [f"[SPK: {speaker}]"] + sentences
+
+        final_text = "[INST] " + " ".join(sentences) + " [/INST]"
         encoded = self.tokenizer.encode(
             final_text,
             add_special_tokens=False,
             truncation=False,
             max_length=10**6,
         )
-        semantic_length = sum([len(i[0].values) for i in final_semantic])
+        semantic_length = sum([len(i[0].values) for i in semantics])
+        bos_bias = 1 if add_bos else 0
+
+        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        tokens = (
+            encoded
+            + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.eos_token_id]
+        )
 
-        # Single codebook
-        if len(final_semantic[0]) == 1:
-            semantic_tokens = [f"<s:{j}>" for i in final_semantic for j in i[0].values]
-            tokenized = self.tokenizer.encode(
-                f" ".join(semantic_tokens),
-                add_special_tokens=False,
-                truncation=False,
-                max_length=10**6,
-            )
+        if add_bos:
+            tokens = [self.tokenizer.bos_token_id] + tokens
 
-            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-            tokens = (
-                [self.tokenizer.bos_token_id]
-                + encoded
-                + tokenized
-                + [self.tokenizer.eos_token_id]
-            )
-            tokens = torch.tensor([tokens], dtype=torch.long)
-            labels = tokens.clone()
-            labels[0, : len(encoded) + 1] = -100  # Mask out the <s> and query tokens
-        else:
-            # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
-            tokens = (
-                [self.tokenizer.bos_token_id]
-                + encoded
-                + [self.tokenizer.pad_token_id] * semantic_length
-                + [self.tokenizer.eos_token_id]
-            )
-            codes = [[0] * (len(encoded) + 1) for _ in range(len(final_semantic[0]))]
-            for segment in final_semantic:
-                for book_idx, book in enumerate(segment):
-                    for j in book.values:
-                        codes[book_idx].append(int(j) + 2)
+        # Codebook bos/padding: 0, eos: 1
+        codes = [
+            [CODEBOOK_BOS_TOKEN_ID] * (len(encoded) + bos_bias)
+            for _ in range(len(semantics[0]))
+        ]
+        for segment in semantics:
+            for book_idx, book in enumerate(segment):
+                for j in book.values:
+                    codes[book_idx].append(int(j) + 2)
 
-            for book in codes:
-                book.append(1)
+        for book in codes:
+            book.append(CODEBOOK_EOS_TOKEN_ID)
 
-            tokens = [tokens] + codes
+        tokens = [tokens] + codes
 
-            tokens = torch.tensor(tokens, dtype=torch.long)
-            labels = tokens.clone()
-            labels[
-                1:, : len(encoded) + 1
-            ] = -100  # Mask out the <s> tokens for semantic
+        tokens = torch.tensor(tokens, dtype=torch.long)
+        labels = tokens.clone()
 
-        return {
-            "tokens": tokens[:, :-1],
-            "labels": labels[:, 1:],
-        }
+        # Mask out the <s> tokens for semantic, predict semantic tokens only
+        # Since we don't mask out the input tokens, the language modeling still works
+        labels[1:, : (len(encoded) + bos_bias)] = -100
+
+        tokens = tokens[:, :-1]
+        labels = labels[:, 1:]
+
+        # Verify the padding is correct, and the last token is eos
+        assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
+        assert (tokens[1:, : len(encoded) + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
+        assert labels[0, -1] == self.tokenizer.eos_token_id
+        assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
+
+        return tokens, labels
 
 
 @dataclass
@@ -346,18 +399,20 @@ class TextDataCollator:
             _tokens = example["tokens"][:, : self.max_length]
             _labels = example["labels"][:, : self.max_length]
             _attention_mask = torch.ones((self.max_length,), dtype=torch.bool)
-            _attention_mask[: _tokens.size(1)] = False
+            tokens_length = _tokens.size(1)
+            _attention_mask[:tokens_length] = False
 
-            assert _tokens.size(1) == _labels.size(
+            assert tokens_length == _labels.size(
                 1
-            ), f"{_tokens.size(1)} != {_labels.size(1)}"
+            ), f"{tokens_length} != {_labels.size(1)}"
 
-            if _tokens.size(1) < self.max_length:
+            if tokens_length < self.max_length:
                 _tokens = F.pad(
                     _tokens,
-                    (0, self.max_length - _tokens.size(1)),
+                    (0, self.max_length - tokens_length),
                     value=self.tokenizer.eos_token_id,
                 )
+                _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
                 _labels = F.pad(
                     _labels, (0, self.max_length - _labels.size(1)), value=-100
                 )
@@ -444,20 +499,14 @@ class TextDataModule(LightningDataModule):
 
 
 if __name__ == "__main__":
-    import json
-
     from tqdm import tqdm
 
     ds = AutoAugTextDataset(
         tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
         use_speaker=True,
+        interactive_prob=1.0,
     )
 
-    # ds = StreamTextDataset(
-    #     prefix="en/",
-    #     tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
-    # )
-
     dm = TextDataModule(
         train_dataset=ds,
         val_dataset=ds,

+ 14 - 6
tools/llama/build_dataset.py

@@ -28,17 +28,25 @@ def task_generator_yaml(config):
             row["group_parent_level"],
         )
 
+        if isinstance(parent_level, int):
+            parent_level = [parent_level]
+
         # Load the files
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
         grouped_files = defaultdict(list)
         for file in files:
-            if parent_level == 1:
-                p = file.parent.name
-            elif parent_level == 2:
-                p = file.parent.parent.name
-            else:
-                raise ValueError(f"Invalid parent level {parent_level}")
+            all_parents = []
+            pointer = file
+            while pointer.parent.name:
+                all_parents.append(pointer.parent.name)
+                pointer = pointer.parent
+
+            ps = []
+            for level in parent_level:
+                ps.append(all_parents[level - 1])
+
+            p = "-".join(ps)
             grouped_files[p].append(file)
 
         logger.info(f"Found {len(grouped_files)} groups in {root}")