Lengyue 2 лет назад
Родитель
Commit
712908e673
1 измененный файл с 23 добавлениями и 23 удалениями
  1. 23 23
      speech_lm/datasets/cultura_x.py

+ 23 - 23
speech_lm/datasets/cultura_x.py

@@ -1,15 +1,17 @@
-from dataclasses import dataclass
 import random
+from dataclasses import dataclass
+from logging import getLogger
+from random import Random
+
+import numpy as np
 import pandas as pd
-from speech_lm.utils.braceexpand import braceexpand
-from torch.utils.data import IterableDataset, get_worker_info
+import pyarrow.parquet as pq
+from datasets.download.streaming_download_manager import xopen
 from torch.distributed import get_rank, get_world_size, is_initialized
-from random import Random
-from logging import getLogger
-from huggingface_hub import hf_hub_download
+from torch.utils.data import IterableDataset, get_worker_info
 from transformers import AutoTokenizer
-import numpy as np
 
+from speech_lm.utils.braceexpand import braceexpand
 
 SUBSETS = {
     "en": "en_part_{00000..03071}",
@@ -21,11 +23,12 @@ log = getLogger(__name__)
 
 
 class CulturaXDataset(IterableDataset):
-    def __init__(self, lang: str, seed: int = 42):
+    def __init__(self, lang: str, seed: int = 42, parquet_batch_size: int = 10000):
         super().__init__()
 
         self.lang = lang
         self.seed = seed
+        self.parquet_batch_size = parquet_batch_size
 
         # Get sharded files
         self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
@@ -68,21 +71,18 @@ class CulturaXDataset(IterableDataset):
                log.exception(f"Failed to parse {filename}")
 
     def parse_data(self, filename: str):
-        fname = hf_hub_download(
-            "uonlp/CulturaX",
-            filename,
-            repo_type="dataset",
-        )
-
-        # Read the file
-        df = pd.read_parquet(fname)
-
-        # Shuffle the data
-        df = df.sample(frac=1.0)
-
-        # Yield the data
-        for text in df["text"]:
-            yield {"text": text}
+        url = f"https://huggingface.co/datasets/uonlp/CulturaX/resolve/main/{filename}"
+
+        with xopen(url, mode="rb") as stream:
+            parquet_file = pq.ParquetFile(stream)
+
+            for batch in parquet_file.iter_batches(
+                batch_size=self.parquet_batch_size, columns=["text"]
+            ):
+                # In-batch shuffling
+                texts = [{"text": text.as_py()} for text in batch["text"]]
+                random.shuffle(texts)
+                yield from texts
 
 
 @dataclass