Преглед изворног кода

Implement optional delay pattern

Lengyue пре 2 година
родитељ
комит
073d2266e4
1 измењен фајл са 11 додатих и 3 уклоњена реда
  1. 11 3
      fish_speech/datasets/text.py

+ 11 - 3
fish_speech/datasets/text.py

@@ -199,6 +199,7 @@ class AutoAugTextDataset(IterableDataset):
         mix_text_phone_prob: float = 0.5,
         use_negative_samples: bool = False,
         num_codebooks: Optional[int] = None,
+        use_delay_pattern: bool = True,
     ):
         """
         Args:
@@ -216,6 +217,7 @@ class AutoAugTextDataset(IterableDataset):
             mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
             use_negative_samples: generate negative samples
             num_codebooks: number of codebooks, if None, it will be automatically detected
+            use_delay_pattern: use delay pattern for codebooks
         """
 
         super().__init__()
@@ -238,6 +240,7 @@ class AutoAugTextDataset(IterableDataset):
         self.mix_text_phone_prob = mix_text_phone_prob
         self.use_negative_samples = use_negative_samples
         self.num_codebooks = num_codebooks
+        self.use_delay_pattern = use_delay_pattern
 
         if use_data_server is True:
             self.channel = grpc.insecure_channel(server)
@@ -494,9 +497,12 @@ class AutoAugTextDataset(IterableDataset):
         bos_bias = 1 if add_bos else 0
 
         # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
+        pad_token_length = semantic_length + (
+            num_codebooks - 1 if self.use_delay_pattern else 0
+        )
         tokens = (
             encoded
-            + [self.tokenizer.pad_token_id] * (semantic_length + num_codebooks - 1)
+            + [self.tokenizer.pad_token_id] * pad_token_length
             + [self.tokenizer.eos_token_id]
         )
 
@@ -506,7 +512,8 @@ class AutoAugTextDataset(IterableDataset):
         # Codebook bos/padding: 0, eos: 1
         # Implement delay pattern
         codes = [
-            [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias + i)
+            [CODEBOOK_PAD_TOKEN_ID]
+            * (prompt_length + bos_bias + (i if self.use_delay_pattern else 0))
             for i in range(num_codebooks)
         ]
         for segment in semantics:
@@ -515,7 +522,8 @@ class AutoAugTextDataset(IterableDataset):
                     codes[book_idx].append(int(j) + 2)
 
         for idx, book in enumerate(codes):
-            book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
+            if self.use_delay_pattern:
+                book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
             book.append(CODEBOOK_EOS_TOKEN_ID)
 
         tokens = [tokens] + codes