|
|
@@ -198,6 +198,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
causual: bool = True,
|
|
|
mix_text_phone_prob: float = 0.5,
|
|
|
use_negative_samples: bool = False,
|
|
|
+ num_codebooks: Optional[int] = None,
|
|
|
):
|
|
|
"""
|
|
|
Args:
|
|
|
@@ -214,6 +215,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
|
mix_text_phone_prob: probability to mix text and phones, if this is 0, then it will be pure text or pure phones
|
|
|
use_negative_samples: generate negative samples
|
|
|
+ num_codebooks: number of codebooks; if None, it is detected automatically
|
|
|
"""
|
|
|
|
|
|
super().__init__()
|
|
|
@@ -235,6 +237,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.causual = causual
|
|
|
self.mix_text_phone_prob = mix_text_phone_prob
|
|
|
self.use_negative_samples = use_negative_samples
|
|
|
+ self.num_codebooks = num_codebooks
|
|
|
|
|
|
if use_data_server is True:
|
|
|
self.channel = grpc.insecure_channel(server)
|
|
|
@@ -484,7 +487,9 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
)
|
|
|
semantic_length = sum([len(i[0].values) for i in semantics])
|
|
|
prompt_length = len(encoded)
|
|
|
- num_codebooks = len(semantics[0])
|
|
|
+ num_codebooks = (
|
|
|
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
|
|
+ )
|
|
|
|
|
|
bos_bias = 1 if add_bos else 0
|
|
|
|
|
|
@@ -505,7 +510,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
for i in range(num_codebooks)
|
|
|
]
|
|
|
for segment in semantics:
|
|
|
- for book_idx, book in enumerate(segment):
|
|
|
+ for book_idx, book in zip(range(num_codebooks), segment):
|
|
|
for j in book.values:
|
|
|
codes[book_idx].append(int(j) + 2)
|
|
|
|
|
|
@@ -520,8 +525,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
|
|
|
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
# Since we don't mask out the input tokens, the language modeling still works
|
|
|
- # labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
- labels[:, : (prompt_length + bos_bias)] = -100
|
|
|
+ labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
|
|
|
tokens = tokens[:, :-1]
|
|
|
labels = labels[:, 1:]
|
|
|
@@ -677,6 +681,7 @@ if __name__ == "__main__":
|
|
|
interactive_prob=1.0,
|
|
|
phones_prob=1.0,
|
|
|
use_negative_samples=False,
|
|
|
+ num_codebooks=4,
|
|
|
)
|
|
|
|
|
|
# ds = AutoAugTextDataset(
|