Lengyue 2 года назад
Родитель
Коммит
cdc3b88c27
1 изменённый файл: 13 добавлено и 14 удалено
  1. 13 14
      fish_speech/datasets/text.py

+ 13 - 14
fish_speech/datasets/text.py

@@ -27,7 +27,7 @@ from fish_speech.utils.braceexpand import braceexpand
 
 log = RankedLogger(__name__, rank_zero_only=True)
 
-CODEBOOK_BOS_TOKEN_ID = 0
+CODEBOOK_PAD_TOKEN_ID = 0
 CODEBOOK_EOS_TOKEN_ID = 1
 
 
@@ -476,11 +476,6 @@ class AutoAugTextDataset(IterableDataset):
             sentences = [f"[SPK: {speaker}]"] + sentences
 
         final_text = "[INST] " + " ".join(sentences) + " [/INST]"
-
-        for segment in semantics:
-            for j in segment[0].values:
-                final_text += f" <s:{int(j)}>"
-
         encoded = self.tokenizer.encode(
             final_text,
             add_special_tokens=False,
@@ -488,14 +483,15 @@ class AutoAugTextDataset(IterableDataset):
             max_length=10**6,
         )
         semantic_length = sum([len(i[0].values) for i in semantics])
-        prompt_length = len(encoded) - semantic_length
+        prompt_length = len(encoded)
+        num_codebooks = len(semantics[0])
 
         bos_bias = 1 if add_bos else 0
 
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
         tokens = (
             encoded
-            # + [self.tokenizer.pad_token_id] * semantic_length
+            + [self.tokenizer.pad_token_id] * (semantic_length + num_codebooks - 1)
             + [self.tokenizer.eos_token_id]
         )
 
@@ -503,16 +499,18 @@ class AutoAugTextDataset(IterableDataset):
             tokens = [self.tokenizer.bos_token_id] + tokens
 
         # Codebook bos/padding: 0, eos: 1
+        # Implement delay pattern
         codes = [
-            [CODEBOOK_BOS_TOKEN_ID] * (prompt_length + bos_bias)
-            for _ in range(len(semantics[0]))
+            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias + i)
+            for i in range(num_codebooks)
         ]
         for segment in semantics:
             for book_idx, book in enumerate(segment):
                 for j in book.values:
                     codes[book_idx].append(int(j) + 2)
 
-        for book in codes:
+        for idx, book in enumerate(codes):
+            book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
             book.append(CODEBOOK_EOS_TOKEN_ID)
 
         tokens = [tokens] + codes
@@ -522,14 +520,15 @@ class AutoAugTextDataset(IterableDataset):
 
         # Mask out the <s> tokens for semantic, predict semantic tokens only
         # Since we don't mask out the input tokens, the language modeling still works
-        labels[1:, : (prompt_length + bos_bias)] = -100
+        # labels[1:, : (prompt_length + bos_bias)] = -100
+        labels[:, : (prompt_length + bos_bias)] = -100
 
         tokens = tokens[:, :-1]
         labels = labels[:, 1:]
 
         # Verify the padding is correct, and the last token is eos
         assert add_bos is False or tokens[0, 0] == self.tokenizer.bos_token_id
-        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_BOS_TOKEN_ID).all()
+        assert (tokens[1:, : prompt_length + bos_bias] == CODEBOOK_PAD_TOKEN_ID).all()
         assert labels[0, -1] == self.tokenizer.eos_token_id
         assert (labels[1:, -1] == CODEBOOK_EOS_TOKEN_ID).all()
 
@@ -677,7 +676,7 @@ if __name__ == "__main__":
         use_speaker=True,
         interactive_prob=1.0,
         phones_prob=1.0,
-        use_negative_samples=True,
+        use_negative_samples=False,
     )
 
     # ds = AutoAugTextDataset(