|
@@ -181,6 +181,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
repetition_prob: float = 0.0,
|
|
repetition_prob: float = 0.0,
|
|
|
max_length: int = 1024,
|
|
max_length: int = 1024,
|
|
|
tokenizer: AutoTokenizer = None,
|
|
tokenizer: AutoTokenizer = None,
|
|
|
|
|
+ use_speaker: bool = True,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
|
Args:
|
|
Args:
|
|
@@ -199,6 +200,7 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
self.max_length = max_length
|
|
self.max_length = max_length
|
|
|
self.tokenizer = tokenizer
|
|
self.tokenizer = tokenizer
|
|
|
self.repetition_prob = repetition_prob
|
|
self.repetition_prob = repetition_prob
|
|
|
|
|
+ self.use_speaker = use_speaker
|
|
|
|
|
|
|
|
# Read all lines, and group by speaker
|
|
# Read all lines, and group by speaker
|
|
|
self.channel = grpc.insecure_channel(server)
|
|
self.channel = grpc.insecure_channel(server)
|
|
@@ -218,6 +220,8 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
for i in phones
|
|
for i in phones
|
|
|
]
|
|
]
|
|
|
)
|
|
)
|
|
|
|
|
+ else:
|
|
|
|
|
+ sentence = clean_text(sentence)
|
|
|
|
|
|
|
|
tokens = self.tokenizer.encode(
|
|
tokens = self.tokenizer.encode(
|
|
|
f"{sentence}",
|
|
f"{sentence}",
|
|
@@ -268,6 +272,9 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
final_text.append(text)
|
|
final_text.append(text)
|
|
|
final_semantic.append(sentence.semantics)
|
|
final_semantic.append(sentence.semantics)
|
|
|
|
|
|
|
|
|
|
+ if self.use_speaker is not None:
|
|
|
|
|
+ final_text = [f"[SPK: {response.name}]"] + final_text
|
|
|
|
|
+
|
|
|
final_text = "[INST] " + " ".join(final_text) + " [/INST]"
|
|
final_text = "[INST] " + " ".join(final_text) + " [/INST]"
|
|
|
encoded = self.tokenizer.encode(
|
|
encoded = self.tokenizer.encode(
|
|
|
final_text,
|
|
final_text,
|
|
@@ -441,15 +448,16 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
- # ds = AutoAugTextDataset(
|
|
|
|
|
- # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
|
|
- # )
|
|
|
|
|
-
|
|
|
|
|
- ds = StreamTextDataset(
|
|
|
|
|
- prefix="en/",
|
|
|
|
|
|
|
+ ds = AutoAugTextDataset(
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
|
|
+ use_speaker=True,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ # ds = StreamTextDataset(
|
|
|
|
|
+ # prefix="en/",
|
|
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
|
|
+ # )
|
|
|
|
|
+
|
|
|
dm = TextDataModule(
|
|
dm = TextDataModule(
|
|
|
train_dataset=ds,
|
|
train_dataset=ds,
|
|
|
val_dataset=ds,
|
|
val_dataset=ds,
|