|
|
@@ -5,7 +5,6 @@ from pathlib import Path
|
|
|
from random import Random
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
-import grpc
|
|
|
import numpy as np
|
|
|
import pyarrow.parquet as pq
|
|
|
import torch
|
|
|
@@ -27,6 +26,7 @@ log = RankedLogger(__name__, rank_zero_only=True)
|
|
|
|
|
|
CODEBOOK_PAD_TOKEN_ID = 0
|
|
|
CODEBOOK_EOS_TOKEN_ID = 1
|
|
|
+SKIP_TEXT_STRING = "<|skip_text|>"
|
|
|
|
|
|
|
|
|
def split_by_rank_worker(files):
|
|
|
@@ -182,6 +182,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
causual: bool = True,
|
|
|
use_negative_samples: bool = False,
|
|
|
num_codebooks: Optional[int] = None,
|
|
|
+ skip_text_prob: float = 0.0,
|
|
|
):
|
|
|
"""
|
|
|
Args:
|
|
|
@@ -194,6 +195,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
causual: use causual sampling when using local data, disable will lead to random sampling
|
|
|
use_negative_samples: generate negative samples
|
|
|
num_codebooks: number of codebooks, if None, it will be automatically detected
|
|
|
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
|
|
|
"""
|
|
|
|
|
|
super().__init__()
|
|
|
@@ -209,6 +211,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.causual = causual
|
|
|
self.use_negative_samples = use_negative_samples
|
|
|
self.num_codebooks = num_codebooks
|
|
|
+ self.skip_text_prob = skip_text_prob
|
|
|
|
|
|
self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
|
|
self.groups = None
|
|
|
@@ -344,6 +347,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
semantics=[sentence.semantics],
|
|
|
speaker=response.name if use_speaker else None,
|
|
|
add_bos=idx == 0,
|
|
|
+ skip_text=random.random() < self.skip_text_prob,
|
|
|
)
|
|
|
|
|
|
all_tokens.append(tokens)
|
|
|
@@ -442,14 +446,19 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
def pack_sentences(
|
|
|
self,
|
|
|
sentences: list[str],
|
|
|
- semantics=list,
|
|
|
+ semantics: list,
|
|
|
speaker: Optional[str] = None,
|
|
|
add_bos: bool = True,
|
|
|
+ skip_text: bool = False,
|
|
|
):
|
|
|
if speaker is None:
|
|
|
speaker = "assistant"
|
|
|
|
|
|
- final_text = "<|im_start|>user<|im_sep|>" + " ".join(sentences) + "<|im_end|>"
|
|
|
+ cated_sentences = " ".join(sentences)
|
|
|
+ if skip_text:
|
|
|
+ cated_sentences = SKIP_TEXT_STRING
|
|
|
+
|
|
|
+ final_text = "<|im_start|>user<|im_sep|>" + cated_sentences + "<|im_end|>"
|
|
|
final_text = final_text + f"<|im_start|>{speaker}<|im_sep|>"
|
|
|
|
|
|
encoded = self.tokenizer.encode(
|
|
|
@@ -496,6 +505,11 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
tokens = torch.tensor(tokens, dtype=torch.long)
|
|
|
labels = tokens.clone()
|
|
|
|
|
|
+ if skip_text:
|
|
|
+ # If text is not provided, the sentence is used for condition only, all labels are -100
|
|
|
+ torch.fill_(labels, -100)
|
|
|
+ return tokens, labels
|
|
|
+
|
|
|
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
|
|
# Since we don't mask out the input tokens, the language modeling still works
|
|
|
labels[1:, : (prompt_length + bos_bias)] = -100
|
|
|
@@ -663,6 +677,7 @@ if __name__ == "__main__":
|
|
|
use_speaker=False,
|
|
|
interactive_prob=1.0,
|
|
|
use_negative_samples=False,
|
|
|
+ skip_text_prob=0.5,
|
|
|
)
|
|
|
|
|
|
for i in ds:
|