|
@@ -2,6 +2,7 @@ import random
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
from logging import getLogger
|
|
from logging import getLogger
|
|
|
from random import Random
|
|
from random import Random
|
|
|
|
|
+from typing import Optional
|
|
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
import pandas as pd
|
|
@@ -23,15 +24,28 @@ log = getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
class CulturaXDataset(IterableDataset):
|
|
class CulturaXDataset(IterableDataset):
|
|
|
- def __init__(self, lang: str, seed: int = 42, parquet_batch_size: int = 10000):
|
|
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ lang: Optional[str] = None,
|
|
|
|
|
+ seed: int = 42,
|
|
|
|
|
+ parquet_batch_size: int = 10000,
|
|
|
|
|
+ repo: str = "uonlp/CulturaX",
|
|
|
|
|
+ files: Optional[list[str]] = None,
|
|
|
|
|
+ ):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.lang = lang
|
|
self.lang = lang
|
|
|
self.seed = seed
|
|
self.seed = seed
|
|
|
self.parquet_batch_size = parquet_batch_size
|
|
self.parquet_batch_size = parquet_batch_size
|
|
|
|
|
+ self.repo = repo
|
|
|
|
|
+
|
|
|
|
|
+ if self.lang is not None:
|
|
|
|
|
+ files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
|
|
|
|
|
+ else:
|
|
|
|
|
+ files = list(files)
|
|
|
|
|
|
|
|
# Get sharded files
|
|
# Get sharded files
|
|
|
- self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
|
|
|
|
|
|
|
+ self.files = files
|
|
|
Random(seed).shuffle(self.files)
|
|
Random(seed).shuffle(self.files)
|
|
|
|
|
|
|
|
def get_data_splits(self, files):
|
|
def get_data_splits(self, files):
|
|
@@ -71,7 +85,7 @@ class CulturaXDataset(IterableDataset):
|
|
|
log.exception(f"Failed to parse {filename}: {e}")
|
|
log.exception(f"Failed to parse {filename}: {e}")
|
|
|
|
|
|
|
|
def parse_data(self, filename: str):
|
|
def parse_data(self, filename: str):
|
|
|
- url = f"https://huggingface.co/datasets/uonlp/CulturaX/resolve/main/{filename}"
|
|
|
|
|
|
|
+ url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
|
|
|
|
|
|
|
|
with xopen(url, mode="rb") as stream:
|
|
with xopen(url, mode="rb") as stream:
|
|
|
parquet_file = pq.ParquetFile(stream)
|
|
parquet_file = pq.ParquetFile(stream)
|
|
@@ -91,16 +105,7 @@ class CulutreXCollator:
|
|
|
max_length: int = 512
|
|
max_length: int = 512
|
|
|
|
|
|
|
|
def __call__(self, examples):
|
|
def __call__(self, examples):
|
|
|
- texts = []
|
|
|
|
|
-
|
|
|
|
|
- for example in examples:
|
|
|
|
|
- text = example["text"]
|
|
|
|
|
-
|
|
|
|
|
- if len(text) <= self.max_length:
|
|
|
|
|
- texts.append(text)
|
|
|
|
|
- else:
|
|
|
|
|
- start = random.randint(0, len(text) - self.max_length)
|
|
|
|
|
- texts.append(text[start : start + self.max_length])
|
|
|
|
|
|
|
+ texts = [i["text"] for i in examples]
|
|
|
|
|
|
|
|
if self.tokenizer.pad_token is None:
|
|
if self.tokenizer.pad_token is None:
|
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
@@ -152,9 +157,12 @@ class InterleaveDataset(IterableDataset):
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
|
+ from speech_lm.datasets.wenet_vq import WenetVQDataset
|
|
|
|
|
+
|
|
|
dataset_en = CulturaXDataset("en")
|
|
dataset_en = CulturaXDataset("en")
|
|
|
dataset_ja = CulturaXDataset("ja")
|
|
dataset_ja = CulturaXDataset("ja")
|
|
|
- dataset = InterleaveDataset([dataset_en, dataset_ja], [0.5, 0.5])
|
|
|
|
|
|
|
+ dataset_wenet = WenetVQDataset()
|
|
|
|
|
+ dataset = InterleaveDataset([dataset_en, dataset_wenet], [0.5, 0.5])
|
|
|
collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
|
|
collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
|
|
|
|
|
|
|
|
for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
|
|
for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
|