Browse source

Update: much more robust dataloader

Lengyue 2 years ago
parent
commit
7650b2ac43

+ 1 - 0
.gitignore

@@ -5,3 +5,4 @@ __pycache__
 /results
 /data
 /*.test.sh
+*.filelist

+ 3 - 0
dockerfile

@@ -31,4 +31,7 @@ WORKDIR /exp
 COPY requirements.txt .
 RUN pip3 install -r requirements.txt && pip3 install encodec --no-deps
 
+COPY . .
+RUN pip3 install -e .
+
 CMD /bin/zsh

+ 0 - 1
requirements.txt

@@ -4,7 +4,6 @@ bitsandbytes>=0.41.1
 peft>=0.5.0
 lightning>=2.0.9.post0
 hydra-core>=1.3.2
-pyrootutils>=1.0.4
 tensorboard>=2.14.1
 natsort>=8.4.0
 einops>=0.7.0

+ 7 - 0
setup.py

@@ -0,0 +1,7 @@
+from setuptools import find_packages, setup
+
+setup(
+    name="speech-lm",
+    version="0.0.1",
+    packages=find_packages(include=["speech_lm", "speech_lm.*"]),
+)

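Together with the dockerfile's new `pip3 install -e .`, this turns `speech_lm` into a regular installable package, which is what lets the pyrootutils bootstrapping be dropped from train.py further down. A quick sanity check of the `include` filter (a hypothetical session from the repo root; the exact package list depends on which subdirectories carry an `__init__.py`):

    from setuptools import find_packages

    # Same include filter as setup.py above; prints the discovered subpackages,
    # e.g. ['speech_lm', 'speech_lm.datasets', 'speech_lm.utils']
    print(find_packages(include=["speech_lm", "speech_lm.*"]))
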
+ 16 - 5
speech_lm/configs/pretrain.yaml

@@ -39,16 +39,27 @@ schedule:
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 1.0
 
+dataset:
+  _target_: datasets.interleave_datasets
+  datasets:
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'en'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'zh'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'ja'
+  probabilities: [0.4, 0.3, 0.3]
+  seed: 42
+
 dataloader:
   _target_: torch.utils.data.DataLoader
-  dataset: 
-    _target_: speech_lm.dataset.build_dataset
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+  dataset: ${dataset}
   batch_size: ${schedule.micro_batch_size}
   num_workers: 4
   collate_fn:
-    _target_: transformers.DefaultDataCollator
+    _target_: speech_lm.datasets.cultura_x.CulturaXCollator
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
 
 optimizer:
   _target_: torch.optim.AdamW

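For readers unfamiliar with Hydra's `_target_` convention used above: `hydra.utils.instantiate` resolves each `_target_` string to a callable and builds nested nodes recursively. A minimal sketch, not part of the commit, that instantiates just the collator node with a stand-in `gpt2` tokenizer in place of the project's `${tokenizer}`:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    node = OmegaConf.create(
        {
            "_target_": "speech_lm.datasets.cultura_x.CulturaXCollator",
            "tokenizer": {
                "_target_": "transformers.AutoTokenizer.from_pretrained",
                "pretrained_model_name_or_path": "gpt2",
            },
            "max_length": 512,
        }
    )

    # Recursive instantiation: the tokenizer node is built first,
    # then passed into the collator's constructor.
    collator = instantiate(node)
    batch = collator([{"text": "hello world"}])
    print(batch["input_ids"].shape)
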
+ 0 - 62
speech_lm/dataset.py

@@ -1,62 +0,0 @@
-import random
-from functools import partial
-
-from datasets import IterableDataset, interleave_datasets, load_dataset
-from datasets.distributed import split_dataset_by_node
-from torch.distributed import get_rank, get_world_size, is_initialized
-
-
-def encode(examples, tokenizer, max_length=512):
-    # Random choice a 512 token window for each example
-    texts = []
-    for text in examples["text"]:
-        if len(text) <= max_length:
-            texts.append(text)
-        else:
-            start = random.randint(0, len(text) - max_length)
-            texts.append(text[start : start + max_length])
-
-    data = tokenizer(
-        texts,
-        truncation=True,
-        padding="max_length",
-        max_length=max_length,
-        return_tensors="pt",
-    )
-    data["labels"] = data["input_ids"].clone()
-    data["labels"][data["attention_mask"] == 0] = -100
-
-    return data
-
-
-def build_dataset(tokenizer, max_length=512):
-    en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
-    ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
-    zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)
-
-    multilingual_dataset: IterableDataset = interleave_datasets(
-        [en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
-    )
-
-    # DDP
-    if is_initialized():
-        multilingual_dataset = split_dataset_by_node(
-            multilingual_dataset,
-            rank=get_rank(),
-            world_size=get_world_size(),
-        )
-
-    multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
-
-    multilingual_dataset = multilingual_dataset.map(
-        partial(encode, tokenizer=tokenizer, max_length=max_length),
-        batched=True,
-        remove_columns=multilingual_dataset.column_names,
-    )
-
-    return multilingual_dataset
-
-
-if __name__ == "__main__":
-    dataset = build_dataset()
-    print(list(dataset.take(16)))

+ 124 - 0
speech_lm/datasets/cultura_x.py

@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+import random
+import pandas as pd
+from speech_lm.utils.braceexpand import braceexpand
+from torch.utils.data import IterableDataset, get_worker_info
+from torch.distributed import get_rank, get_world_size, is_initialized
+from random import Random
+from logging import getLogger
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer
+
+SUBSETS = {
+    "en": "en_part_{00000..03071}",
+    "zh": "zh_part_{00000..00319}",
+    "ja": "ja_part_{00000..00159}",
+}
+
+log = getLogger(__name__)
+
+
+class CulturaXDataset(IterableDataset):
+    def __init__(self, lang: str, seed: int = 42):
+        super().__init__()
+
+        self.lang = lang
+        self.seed = seed
+
+        # Get sharded files
+        self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
+        Random(seed).shuffle(self.files)
+
+    def get_data_splits(self, files):
+        # We need to know the total number of devices
+        # to split the data properly
+
+        total_devices = 1
+        if is_initialized():
+            total_devices = get_world_size()
+        
+        worker_info = get_worker_info()
+        if worker_info is not None:
+            total_devices *= worker_info.num_workers
+        
+        if len(files) < total_devices:
+            # Repeat the files N times to match the number of devices
+            files = files * (total_devices // len(files) + 1)
+
+        # DDP
+        if is_initialized():
+            files = files[get_rank() :: get_world_size()]
+
+        # Split by worker
+        if worker_info is not None:
+            files = files[worker_info.id :: worker_info.num_workers]
+
+        return files
+
+    def __iter__(self):
+        files = self.get_data_splits(self.files)
+        random.shuffle(files)
+
+        for filename in files:
+            yield from self.parse_data(filename)
+
+    def parse_data(self, filename: str):
+        fname = hf_hub_download(
+            "uonlp/CulturaX",
+            filename,
+            repo_type="dataset",
+        )
+
+        # Read the file
+        df = pd.read_parquet(fname)
+
+        # Shuffle the data
+        df = df.sample(frac=1.0)
+
+        # Yield the data
+        for text in df["text"]:
+            yield {"text": text}
+
+
+@dataclass
+class CulturaXCollator:
+    tokenizer: AutoTokenizer
+    max_length: int = 512
+
+    def __call__(self, examples):
+        texts = []
+
+        for example in examples:
+            text = example["text"]
+
+            if len(text) <= self.max_length:
+                texts.append(text)
+            else:
+                start = random.randint(0, len(text) - self.max_length)
+                texts.append(text[start : start + self.max_length])
+
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        data = self.tokenizer(
+            texts,
+            truncation=True,
+            padding=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        )
+
+        data["labels"] = data["input_ids"].clone()
+        data["labels"][data["attention_mask"] == 0] = -100
+
+        return data
+
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
+    dataset = CulturaXDataset("en")
+    collator = CulturaXCollator(AutoTokenizer.from_pretrained("gpt2"))
+
+    for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
+        print(batch)

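The two strided slices in `get_data_splits` compose so that every (rank, worker) pair reads a disjoint subset of shards and, given enough files, every shard is read exactly once. A standalone sketch of the same arithmetic with made-up shard names, world size, and worker count:

    files = [f"part_{i:05d}" for i in range(8)]
    world_size, num_workers = 2, 2

    for rank in range(world_size):
        per_rank = files[rank::world_size]  # DDP split across processes
        for worker_id in range(num_workers):
            # DataLoader split across workers within one process
            print(rank, worker_id, per_rank[worker_id::num_workers])

    # rank 0, worker 0: ['part_00000', 'part_00004']
    # rank 0, worker 1: ['part_00002', 'part_00006']
    # rank 1, worker 0: ['part_00001', 'part_00005']
    # rank 1, worker 1: ['part_00003', 'part_00007']
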
+ 4 - 7
speech_lm/train.py

@@ -1,26 +1,23 @@
 from pathlib import Path
 
 import hydra
-import pyrootutils
 import torch
 from lightning.fabric import Fabric
+from natsort import natsorted
 from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
 from transformers import LlamaForCausalLM
 from transformers.utils import is_flash_attn_available
-from natsort import natsorted
+
+from speech_lm.logger import RankedLogger
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
 torch.backends.cudnn.allow_tf32 = True
 
-# register eval resolver and root
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
 
-# flake8: noqa: E402
-from speech_lm.logger import RankedLogger
-
 log = RankedLogger(__name__, rank_zero_only=True)
 
 

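The `eval` resolver registered here is what makes the `${eval: ...}` expression in pretrain.yaml work: OmegaConf resolves the nested interpolations and hands the resulting string to Python's `eval`. A minimal self-contained demo mirroring the schedule block, with illustrative batch sizes:

    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("eval", eval)

    cfg = OmegaConf.create(
        {
            "schedule": {
                "batch_size": 64,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}",
            }
        }
    )

    # The resolver receives the string "64 // 8" and evaluates it.
    print(cfg.schedule.gradient_accumulation_steps)  # -> 8
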
+ 217 - 0
speech_lm/utils/braceexpand.py

@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+    pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+    """braceexpand(pattern) -> iterator over generated strings
+
+    Returns an iterator over the strings resulting from brace expansion
+    of pattern. This function implements Brace Expansion as described in
+    bash(1), with the following limitations:
+
+    * A pattern containing unbalanced braces will raise an
+      UnbalancedBracesError exception. In bash, unbalanced braces will either
+      be partly expanded or ignored.
+
+    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+      include the characters '[]^_`' between 'Z' and 'a'.
+
+    When escape is True (the default), characters in pattern can be
+    prefixed with a backslash to cause them not to be interpreted as
+    special characters for brace expansion (such as '{', '}', ',').
+    To pass through a literal backslash, double it ('\\\\').
+
+    When escape is False, backslashes in pattern have no special
+    meaning and will be preserved in the output.
+
+    Examples:
+
+    >>> from braceexpand import braceexpand
+
+    # Integer range
+    >>> list(braceexpand('item{1..3}'))
+    ['item1', 'item2', 'item3']
+
+    # Character range
+    >>> list(braceexpand('{a..c}'))
+    ['a', 'b', 'c']
+
+    # Sequence
+    >>> list(braceexpand('index.html{,.backup}'))
+    ['index.html', 'index.html.backup']
+
+    # Nested patterns
+    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+    # Prefixing an integer with zero causes all numbers to be padded to
+    # the same width.
+    >>> list(braceexpand('{07..10}'))
+    ['07', '08', '09', '10']
+
+    # An optional increment can be specified for ranges.
+    >>> list(braceexpand('{a..g..2}'))
+    ['a', 'c', 'e', 'g']
+
+    # Ranges can go in both directions.
+    >>> list(braceexpand('{4..1}'))
+    ['4', '3', '2', '1']
+
+    # Numbers can be negative
+    >>> list(braceexpand('{2..-1}'))
+    ['2', '1', '0', '-1']
+
+    # Unbalanced braces raise an exception.
+    >>> list(braceexpand('{1{2,3}'))
+    Traceback (most recent call last):
+        ...
+    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+    # By default, the backslash is the escape character.
+    >>> list(braceexpand(r'{1\\{2,3}'))
+    ['1{2', '3']
+
+    # Setting 'escape' to False disables backslash escaping.
+    >>> list(braceexpand(r'\\{1,2}', escape=False))
+    ['\\\\1', '\\\\2']
+
+    """
+    return (
+        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+    )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'pattern:', pattern
+    while pos < len(pattern):
+        if escape and pattern[pos] == "\\":
+            pos += 2
+            continue
+        elif pattern[pos] == "{":
+            if bracketdepth == 0 and pos > start:
+                # print 'literal:', pattern[start:pos]
+                items.append([pattern[start:pos]])
+                start = pos
+            bracketdepth += 1
+        elif pattern[pos] == "}":
+            bracketdepth -= 1
+            if bracketdepth == 0:
+                # print 'expression:', pattern[start+1:pos]
+                expr = pattern[start + 1 : pos]
+                item = parse_expression(expr, escape)
+                if item is None:  # not a range or sequence
+                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                else:
+                    items.append(item)
+                start = pos + 1  # skip the closing brace
+        pos += 1
+
+    if bracketdepth != 0:  # unbalanced braces
+        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+    if start < pos:
+        items.append([pattern[start:]])
+
+    return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+    int_range_match = int_range_re.match(expr)
+    if int_range_match:
+        return make_int_range(*int_range_match.groups())
+
+    char_range_match = char_range_re.match(expr)
+    if char_range_match:
+        return make_char_range(*char_range_match.groups())
+
+    return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+    # sequence -> chain(*sequence_items)
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'sequence:', seq
+    while pos < len(seq):
+        if escape and seq[pos] == "\\":
+            pos += 2
+            continue
+        elif seq[pos] == "{":
+            bracketdepth += 1
+        elif seq[pos] == "}":
+            bracketdepth -= 1
+        elif seq[pos] == "," and bracketdepth == 0:
+            items.append(parse_pattern(seq[start:pos], escape))
+            start = pos + 1  # skip the comma
+        pos += 1
+
+    if bracketdepth != 0:
+        raise UnbalancedBracesError
+    if not items:
+        return None
+
+    # part after the last comma (may be the empty string)
+    items.append(parse_pattern(seq[start:], escape))
+    return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+    if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+        padding = max(len(left), len(right))
+    else:
+        padding = 0
+    step = (int(incr) or 1) if incr else 1
+    start = int(left)
+    end = int(right)
+    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+    fmt = "%0{}d".format(padding)
+    return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+    step = (int(incr) or 1) if incr else 1
+    start = alphabet.index(left)
+    end = alphabet.index(right)
+    if start < end:
+        return alphabet[start : end + 1 : step]
+    else:
+        end = end or -len(alphabet)
+        return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+    import doctest
+    import sys
+
+    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+    if failed:
+        sys.exit(1)
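
This is what expands the shard patterns in `SUBSETS` above into concrete parquet paths, zero-padding included. A quick check (hypothetical session, with the package importable):

    from speech_lm.utils.braceexpand import braceexpand

    files = list(braceexpand("ja/ja_part_{00000..00159}.parquet"))
    print(len(files))           # 160
    print(files[0], files[-1])  # ja/ja_part_00000.parquet ja/ja_part_00159.parquet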