|
|
@@ -199,6 +199,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
mix_text_phone_prob: float = 0.5,
|
|
|
use_negative_samples: bool = False,
|
|
|
num_codebooks: Optional[int] = None,
|
|
|
+ use_delay_pattern: bool = True,
|
|
|
):
|
|
|
"""
|
|
|
Args:
|
|
|
@@ -216,6 +217,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
|
|
|
use_negative_samples: generate negative samples
|
|
|
num_codebooks: number of codebooks, if None, it will be automatically detected
|
|
|
+ use_delay_pattern: use delay pattern for codebooks
|
|
|
"""
|
|
|
|
|
|
super().__init__()
|
|
|
@@ -238,6 +240,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.mix_text_phone_prob = mix_text_phone_prob
|
|
|
self.use_negative_samples = use_negative_samples
|
|
|
self.num_codebooks = num_codebooks
|
|
|
+ self.use_delay_pattern = use_delay_pattern
|
|
|
|
|
|
if use_data_server is True:
|
|
|
self.channel = grpc.insecure_channel(server)
|
|
|
@@ -494,9 +497,12 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
bos_bias = 1 if add_bos else 0
|
|
|
|
|
|
# Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
|
|
+ pad_token_length = semantic_length + (
|
|
|
+ num_codebooks - 1 if self.use_delay_pattern else 0
|
|
|
+ )
|
|
|
tokens = (
|
|
|
encoded
|
|
|
- + [self.tokenizer.pad_token_id] * (semantic_length + num_codebooks - 1)
|
|
|
+ + [self.tokenizer.pad_token_id] * pad_token_length
|
|
|
+ [self.tokenizer.eos_token_id]
|
|
|
)
|
|
|
|
|
|
@@ -506,7 +512,8 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
# Codebook bos/padding: 0, eos: 1
|
|
|
# Implement delay pattern
|
|
|
codes = [
|
|
|
- [CODEBOOK_PAD_TOKEN_ID] * (prompt_length + bos_bias + i)
|
|
|
+ [CODEBOOK_PAD_TOKEN_ID]
|
|
|
+ * (prompt_length + bos_bias + (i if self.use_delay_pattern else 0))
|
|
|
for i in range(num_codebooks)
|
|
|
]
|
|
|
for segment in semantics:
|
|
|
@@ -515,7 +522,8 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
codes[book_idx].append(int(j) + 2)
|
|
|
|
|
|
for idx, book in enumerate(codes):
|
|
|
- book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
|
|
|
+ if self.use_delay_pattern:
|
|
|
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * (len(codes) - idx - 1))
|
|
|
book.append(CODEBOOK_EOS_TOKEN_ID)
|
|
|
|
|
|
tokens = [tokens] + codes
|