|
|
@@ -1,15 +1,17 @@
|
|
|
-from dataclasses import dataclass
|
|
|
import random
|
|
|
+from dataclasses import dataclass
|
|
|
+from logging import getLogger
|
|
|
+from random import Random
|
|
|
+
|
|
|
+import numpy as np
|
|
|
import pandas as pd
|
|
|
-from speech_lm.utils.braceexpand import braceexpand
|
|
|
-from torch.utils.data import IterableDataset, get_worker_info
|
|
|
+import pyarrow.parquet as pq
|
|
|
+from datasets.download.streaming_download_manager import xopen
|
|
|
from torch.distributed import get_rank, get_world_size, is_initialized
|
|
|
-from random import Random
|
|
|
-from logging import getLogger
|
|
|
-from huggingface_hub import hf_hub_download
|
|
|
+from torch.utils.data import IterableDataset, get_worker_info
|
|
|
from transformers import AutoTokenizer
|
|
|
-import numpy as np
|
|
|
|
|
|
+from speech_lm.utils.braceexpand import braceexpand
|
|
|
|
|
|
SUBSETS = {
|
|
|
"en": "en_part_{00000..03071}",
|
|
|
@@ -21,11 +23,12 @@ log = getLogger(__name__)
|
|
|
|
|
|
|
|
|
class CulturaXDataset(IterableDataset):
|
|
|
- def __init__(self, lang: str, seed: int = 42):
|
|
|
+ def __init__(self, lang: str, seed: int = 42, parquet_batch_size: int = 10000):
|
|
|
super().__init__()
|
|
|
|
|
|
self.lang = lang
|
|
|
self.seed = seed
|
|
|
+ self.parquet_batch_size = parquet_batch_size
|
|
|
|
|
|
# Get sharded files
|
|
|
self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
|
|
|
@@ -68,21 +71,18 @@ class CulturaXDataset(IterableDataset):
|
|
|
log.exception(f"Failed to parse {filename}")
|
|
|
|
|
|
def parse_data(self, filename: str):
|
|
|
- fname = hf_hub_download(
|
|
|
- "uonlp/CulturaX",
|
|
|
- filename,
|
|
|
- repo_type="dataset",
|
|
|
- )
|
|
|
-
|
|
|
- # Read the file
|
|
|
- df = pd.read_parquet(fname)
|
|
|
-
|
|
|
- # Shuffle the data
|
|
|
- df = df.sample(frac=1.0)
|
|
|
-
|
|
|
- # Yield the data
|
|
|
- for text in df["text"]:
|
|
|
- yield {"text": text}
|
|
|
+ url = f"https://huggingface.co/datasets/uonlp/CulturaX/resolve/main/{filename}"
|
|
|
+
|
|
|
+ with xopen(url, mode="rb") as stream:
|
|
|
+ parquet_file = pq.ParquetFile(stream)
|
|
|
+
|
|
|
+ for batch in parquet_file.iter_batches(
|
|
|
+ batch_size=self.parquet_batch_size, columns=["text"]
|
|
|
+ ):
|
|
|
+ # Shuffle only within this batch (no shuffling across batches)
|
|
|
+ texts = [{"text": text.as_py()} for text in batch["text"]]
|
|
|
+ random.shuffle(texts)
|
|
|
+ yield from texts
|
|
|
|
|
|
|
|
|
@dataclass
|