|
|
@@ -1,9 +1,6 @@
|
|
|
-import json
|
|
|
import random
|
|
|
-import re
|
|
|
from dataclasses import dataclass
|
|
|
from itertools import chain
|
|
|
-from pathlib import Path
|
|
|
from random import Random
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
@@ -65,12 +62,16 @@ class StreamTextDataset(IterableDataset):
|
|
|
seed: int = 42,
|
|
|
parquet_batch_size: int = 10000,
|
|
|
repo: str = "uonlp/CulturaX",
|
|
|
+ max_length: int = 1024,
|
|
|
+ tokenizer: AutoTokenizer = None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
self.seed = seed
|
|
|
self.parquet_batch_size = parquet_batch_size
|
|
|
self.repo = repo
|
|
|
+ self.max_length = max_length
|
|
|
+ self.tokenizer = tokenizer
|
|
|
|
|
|
if files is None and prefix is None:
|
|
|
raise ValueError("Either files or prefix must be specified")
|
|
|
@@ -105,21 +106,48 @@ class StreamTextDataset(IterableDataset):
|
|
|
def parse_data(self, filename: str):
|
|
|
for data in self.parse_data_internal(filename):
|
|
|
text = data["text"]
|
|
|
- expression = re.compile(r"\[INST\] (.*) \[/INST\] (.*) </s>")
|
|
|
- match = expression.match(text)
|
|
|
|
|
|
- if match is None:
|
|
|
- continue
|
|
|
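+ # With 30% probability, model phones: wrap every character that is neither in pu_symbols nor pad_symbol as <p:char>, space-separated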
|
|
|
+ if random.random() < 0.3:
|
|
|
+ text = " ".join(
|
|
|
+ [
|
|
|
+ (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
|
|
|
+ for i in text
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ # Encode the full text without special tokens and without truncation
|
|
|
+ tokens = self.tokenizer.encode(
|
|
|
+ text,
|
|
|
+ add_special_tokens=False,
|
|
|
+ truncation=False,
|
|
|
+ max_length=10**6,
|
|
|
+ )
|
|
|
+
|
|
|
+ # If the text is longer than self.max_length, keep a random contiguous window of max_length - 1 tokens
|
|
|
+ if len(tokens) > self.max_length:
|
|
|
+ start = random.randint(0, len(tokens) - self.max_length)
|
|
|
+ tokens = tokens[start : start + self.max_length - 1]
|
|
|
|
|
|
- text = match.group(1)
|
|
|
- semantic = match.group(2)
|
|
|
+ tokens = (
|
|
|
+ [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
|
|
|
+ )
|
|
|
+ # Stack 4 all-zero placeholder rows under the text-token row, giving shape (5, len(tokens))
|
|
|
+ placeholder_multi_codebook = torch.zeros((4, len(tokens)), dtype=torch.long)
|
|
|
|
|
|
- # Convert semantic to ids
|
|
|
- expression = re.compile(r"<semantic_(\d+)>")
|
|
|
- # 0 and 1 are reserved for <s> and </s>
|
|
|
- semantic = [0] + [int(i) + 2 for i in expression.findall(semantic)] + [1]
|
|
|
+ tokens = torch.concat(
|
|
|
+ [
|
|
|
+ torch.tensor([tokens], dtype=torch.long),
|
|
|
+ placeholder_multi_codebook,
|
|
|
+ ],
|
|
|
+ dim=0,
|
|
|
+ )
|
|
|
+ labels = tokens.clone()
|
|
|
+ tokens = tokens[:, :-1]
|
|
|
+ labels = labels[:, 1:]
|
|
|
+ labels[1:] = -100 # remove all placeholders
|
|
|
|
|
|
- yield {"text": text, "semantic": [semantic]}
|
|
|
+ yield {"tokens": tokens, "labels": labels}
|
|
|
|
|
|
def parse_data_internal(self, filename: str):
|
|
|
url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
|
|
|
@@ -190,8 +218,6 @@ class AutoAugTextDataset(IterableDataset):
|
|
|
for i in phones
|
|
|
]
|
|
|
)
|
|
|
- else:
|
|
|
- sentence = clean_text(sentence)
|
|
|
|
|
|
tokens = self.tokenizer.encode(
|
|
|
f"{sentence}",
|
|
|
@@ -415,7 +441,12 @@ if __name__ == "__main__":
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
- ds = AutoAugTextDataset(
|
|
|
+ # ds = AutoAugTextDataset(
|
|
|
+ # tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
+ # )
|
|
|
+
|
|
|
+ ds = StreamTextDataset(
|
|
|
+ prefix="en/",
|
|
|
tokenizer=AutoTokenizer.from_pretrained("fishaudio/speech-lm-v1"),
|
|
|
)
|
|
|
|
|
|
@@ -423,10 +454,11 @@ if __name__ == "__main__":
|
|
|
train_dataset=ds,
|
|
|
val_dataset=ds,
|
|
|
tokenizer=ds.tokenizer,
|
|
|
- batch_size=16,
|
|
|
+ batch_size=2,
|
|
|
max_length=1024,
|
|
|
num_workers=0,
|
|
|
)
|
|
|
|
|
|
for batch in tqdm(dm.train_dataloader()):
|
|
|
- pass
|
|
|
+ print(batch)
|
|
|
+ break
|