
Update: much more robust dataloader

Lengyue, 2 years ago
parent commit 7650b2ac43

+ 1 - 0
.gitignore

@@ -5,3 +5,4 @@ __pycache__
 /results
 /data
 /*.test.sh
+*.filelist

+ 3 - 0
dockerfile

@@ -31,4 +31,7 @@ WORKDIR /exp
 COPY requirements.txt .
 RUN pip3 install -r requirements.txt && pip3 install encodec --no-deps
 
+COPY . .
+RUN pip3 install -e .
+
 CMD /bin/zsh

+ 0 - 1
requirements.txt

@@ -4,7 +4,6 @@ bitsandbytes>=0.41.1
 peft>=0.5.0
 lightning>=2.0.9.post0
 hydra-core>=1.3.2
-pyrootutils>=1.0.4
 tensorboard>=2.14.1
 natsort>=8.4.0
 einops>=0.7.0

+ 7 - 0
setup.py

@@ -0,0 +1,7 @@
+from setuptools import find_packages, setup
+
+setup(
+    name="speech-lm",
+    version="0.0.1",
+    packages=find_packages(include=["speech_lm", "speech_lm.*"]),
+)
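
Note: together with the `COPY . .` / `RUN pip3 install -e .` lines added to the dockerfile, this minimal setup.py installs the repository as an editable package, which is what lets the commit drop the pyrootutils path hack from train.py. A quick smoke test, assuming the package directories ship `__init__.py` files (which `find_packages` requires):

    # Hypothetical check that the editable install worked:
    from speech_lm.datasets.cultura_x import CulturaXDataset

    dataset = CulturaXDataset("en")
    print(dataset.files[:2])  # first two (shuffled) parquet shard names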

+ 16 - 5
speech_lm/configs/pretrain.yaml

@@ -39,16 +39,27 @@ schedule:
   gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
   clip_grad_norm: 1.0
 
+dataset:
+  _target_: datasets.interleave_datasets
+  datasets:
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'en'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'zh'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'ja'
+  probabilities: [0.4, 0.3, 0.3]
+  seed: 42
+
 dataloader:
   _target_: torch.utils.data.DataLoader
-  dataset: 
-    _target_: speech_lm.dataset.build_dataset
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+  dataset: ${dataset}
   batch_size: ${schedule.micro_batch_size}
   num_workers: 4
   collate_fn:
-    _target_: transformers.DefaultDataCollator
+    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
 
 optimizer:
   _target_: torch.optim.AdamW
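
Note: Hydra resolves `_target_` entries recursively, so the `${dataset}` interpolation and the nested collator are all built from a single call. A minimal sketch, assuming this config has been composed into `cfg` (e.g. via `@hydra.main`):

    from hydra.utils import instantiate

    # Recursively follows the _target_ entries: DataLoader ->
    # datasets.interleave_datasets over the three CulturaXDataset
    # instances -> the collator with its tokenizer.
    dataloader = instantiate(cfg.dataloader)
    batch = next(iter(dataloader))  # input_ids / attention_mask / labels tensors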

+ 0 - 62
speech_lm/dataset.py

@@ -1,62 +0,0 @@
-import random
-from functools import partial
-
-from datasets import IterableDataset, interleave_datasets, load_dataset
-from datasets.distributed import split_dataset_by_node
-from torch.distributed import get_rank, get_world_size, is_initialized
-
-
-def encode(examples, tokenizer, max_length=512):
-    # Random choice a 512 token window for each example
-    texts = []
-    for text in examples["text"]:
-        if len(text) <= max_length:
-            texts.append(text)
-        else:
-            start = random.randint(0, len(text) - max_length)
-            texts.append(text[start : start + max_length])
-
-    data = tokenizer(
-        texts,
-        truncation=True,
-        padding="max_length",
-        max_length=max_length,
-        return_tensors="pt",
-    )
-    data["labels"] = data["input_ids"].clone()
-    data["labels"][data["attention_mask"] == 0] = -100
-
-    return data
-
-
-def build_dataset(tokenizer, max_length=512):
-    en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
-    ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
-    zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)
-
-    multilingual_dataset: IterableDataset = interleave_datasets(
-        [en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
-    )
-
-    # DDP
-    if is_initialized():
-        multilingual_dataset = split_dataset_by_node(
-            multilingual_dataset,
-            rank=get_rank(),
-            world_size=get_world_size(),
-        )
-
-    multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
-
-    multilingual_dataset = multilingual_dataset.map(
-        partial(encode, tokenizer=tokenizer, max_length=max_length),
-        batched=True,
-        remove_columns=multilingual_dataset.column_names,
-    )
-
-    return multilingual_dataset
-
-
-if __name__ == "__main__":
-    dataset = build_dataset()
-    print(list(dataset.take(16)))

+ 124 - 0
speech_lm/datasets/cultura_x.py

@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+import random
+import pandas as pd
+from speech_lm.utils.braceexpand import braceexpand
+from torch.utils.data import IterableDataset, get_worker_info
+from torch.distributed import get_rank, get_world_size, is_initialized
+from random import Random
+from logging import getLogger
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer
+
+SUBSETS = {
+    "en": "en_part_{00000..03071}",
+    "zh": "zh_part_{00000..00319}",
+    "ja": "ja_part_{00000..00159}",
+}
+
+log = getLogger(__name__)
+
+
+class CulturaXDataset(IterableDataset):
+    def __init__(self, lang: str, seed: int = 42):
+        super().__init__()
+
+        self.lang = lang
+        self.seed = seed
+
+        # Get sharded files
+        self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
+        Random(seed).shuffle(self.files)
+
+    def get_data_splits(self, files):
+        # We need to know the total number of devices
+        # to split the data properly
+
+        total_devices = 1
+        if is_initialized():
+            total_devices = get_world_size()
+        
+        worker_info = get_worker_info()
+        if worker_info is not None:
+            total_devices *= worker_info.num_workers
+        
+        if len(files) < total_devices:
+            # Repeat the files N times to match the number of devices
+            files = files * (total_devices // len(files) + 1)
+
+        # DDP
+        if is_initialized():
+            files = files[get_rank() :: get_world_size()]
+
+        # Split by worker
+        if worker_info is not None:
+            files = files[worker_info.id :: worker_info.num_workers]
+
+        return files
+
+    def __iter__(self):
+        files = self.get_data_splits(self.files)
+        random.shuffle(files)
+
+        for filename in files:
+            yield from self.parse_data(filename)
+
+    def parse_data(self, filename: str):
+        fname = hf_hub_download(
+            "uonlp/CulturaX",
+            filename,
+            repo_type="dataset",
+        )
+
+        # Read the file
+        df = pd.read_parquet(fname)
+
+        # Shuffle the data
+        df = df.sample(frac=1.0)
+
+        # Yield the data
+        for text in df["text"]:
+            yield {"text": text}
+
+
+@dataclass
+class CulutreXCollator:
+    tokenizer: AutoTokenizer
+    max_length: int = 512
+
+    def __call__(self, examples):
+        texts = []
+
+        for example in examples:
+            text = example["text"]
+
+            if len(text) <= self.max_length:
+                texts.append(text)
+            else:
+                start = random.randint(0, len(text) - self.max_length)
+                texts.append(text[start : start + self.max_length])
+
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        data = self.tokenizer(
+            texts,
+            truncation=True,
+            padding=True,
+            max_length=self.max_length,
+            return_tensors="pt",
+        )
+
+        data["labels"] = data["input_ids"].clone()
+        data["labels"][data["attention_mask"] == 0] = -100
+
+        return data
+
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
+    dataset = CulturaXDataset("en")
+    collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
+
+    for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
+        print(batch)
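
Note: `get_data_splits` shards the file list at two levels with strided slicing, so each (DDP rank, dataloader worker) pair sees a disjoint subset. A toy illustration with hypothetical numbers:

    files = [f"part_{i:05d}" for i in range(8)]
    world_size, rank = 2, 0        # DDP processes
    num_workers, worker_id = 2, 1  # DataLoader workers inside this rank

    per_rank = files[rank::world_size]             # ranks take alternating shards
    per_worker = per_rank[worker_id::num_workers]  # workers split the rank's share
    print(per_worker)  # ['part_00002', 'part_00006']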

+ 4 - 7
speech_lm/train.py

@@ -1,26 +1,23 @@
 from pathlib import Path
 
 import hydra
-import pyrootutils
 import torch
 from lightning.fabric import Fabric
+from natsort import natsorted
 from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
 from transformers import LlamaForCausalLM
 from transformers.utils import is_flash_attn_available
-from natsort import natsorted
+
+from speech_lm.logger import RankedLogger
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
 torch.backends.cudnn.allow_tf32 = True
 
-# register eval resolver and root
-pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
 
-# flake8: noqa: E402
-from speech_lm.logger import RankedLogger
-
 log = RankedLogger(__name__, rank_zero_only=True)
 
 

+ 217 - 0
speech_lm/utils/braceexpand.py

@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+    pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+    """braceexpand(pattern) -> iterator over generated strings
+
+    Returns an iterator over the strings resulting from brace expansion
+    of pattern. This function implements Brace Expansion as described in
+    bash(1), with the following limitations:
+
+    * A pattern containing unbalanced braces will raise an
+      UnbalancedBracesError exception. In bash, unbalanced braces will either
+      be partly expanded or ignored.
+
+    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+      include the characters '[]^_`' between 'Z' and 'a'.
+
+    When escape is True (the default), characters in pattern can be
+    prefixed with a backslash to cause them not to be interpreted as
+    special characters for brace expansion (such as '{', '}', ',').
+    To pass through a literal backslash, double it ('\\\\').
+
+    When escape is False, backslashes in pattern have no special
+    meaning and will be preserved in the output.
+
+    Examples:
+
+    >>> from braceexpand import braceexpand
+
+    # Integer range
+    >>> list(braceexpand('item{1..3}'))
+    ['item1', 'item2', 'item3']
+
+    # Character range
+    >>> list(braceexpand('{a..c}'))
+    ['a', 'b', 'c']
+
+    # Sequence
+    >>> list(braceexpand('index.html{,.backup}'))
+    ['index.html', 'index.html.backup']
+
+    # Nested patterns
+    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+    # Prefixing an integer with zero causes all numbers to be padded to
+    # the same width.
+    >>> list(braceexpand('{07..10}'))
+    ['07', '08', '09', '10']
+
+    # An optional increment can be specified for ranges.
+    >>> list(braceexpand('{a..g..2}'))
+    ['a', 'c', 'e', 'g']
+
+    # Ranges can go in both directions.
+    >>> list(braceexpand('{4..1}'))
+    ['4', '3', '2', '1']
+
+    # Numbers can be negative
+    >>> list(braceexpand('{2..-1}'))
+    ['2', '1', '0', '-1']
+
+    # Unbalanced braces raise an exception.
+    >>> list(braceexpand('{1{2,3}'))
+    Traceback (most recent call last):
+        ...
+    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+    # By default, the backslash is the escape character.
+    >>> list(braceexpand(r'{1\\{2,3}'))
+    ['1{2', '3']
+
+    # Setting 'escape' to False disables backslash escaping.
+    >>> list(braceexpand(r'\\{1,2}', escape=False))
+    ['\\\\1', '\\\\2']
+
+    """
+    return (
+        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+    )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'pattern:', pattern
+    while pos < len(pattern):
+        if escape and pattern[pos] == "\\":
+            pos += 2
+            continue
+        elif pattern[pos] == "{":
+            if bracketdepth == 0 and pos > start:
+                # print 'literal:', pattern[start:pos]
+                items.append([pattern[start:pos]])
+                start = pos
+            bracketdepth += 1
+        elif pattern[pos] == "}":
+            bracketdepth -= 1
+            if bracketdepth == 0:
+                # print 'expression:', pattern[start+1:pos]
+                expr = pattern[start + 1 : pos]
+                item = parse_expression(expr, escape)
+                if item is None:  # not a range or sequence
+                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                else:
+                    items.append(item)
+                start = pos + 1  # skip the closing brace
+        pos += 1
+
+    if bracketdepth != 0:  # unbalanced braces
+        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+    if start < pos:
+        items.append([pattern[start:]])
+
+    return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+    int_range_match = int_range_re.match(expr)
+    if int_range_match:
+        return make_int_range(*int_range_match.groups())
+
+    char_range_match = char_range_re.match(expr)
+    if char_range_match:
+        return make_char_range(*char_range_match.groups())
+
+    return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+    # sequence -> chain(*sequence_items)
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'sequence:', seq
+    while pos < len(seq):
+        if escape and seq[pos] == "\\":
+            pos += 2
+            continue
+        elif seq[pos] == "{":
+            bracketdepth += 1
+        elif seq[pos] == "}":
+            bracketdepth -= 1
+        elif seq[pos] == "," and bracketdepth == 0:
+            items.append(parse_pattern(seq[start:pos], escape))
+            start = pos + 1  # skip the comma
+        pos += 1
+
+    if bracketdepth != 0:
+        raise UnbalancedBracesError
+    if not items:
+        return None
+
+    # part after the last comma (may be the empty string)
+    items.append(parse_pattern(seq[start:], escape))
+    return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+    if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+        padding = max(len(left), len(right))
+    else:
+        padding = 0
+    step = (int(incr) or 1) if incr else 1
+    start = int(left)
+    end = int(right)
+    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+    fmt = "%0{}d".format(padding)
+    return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+    step = (int(incr) or 1) if incr else 1
+    start = alphabet.index(left)
+    end = alphabet.index(right)
+    if start < end:
+        return alphabet[start : end + 1 : step]
+    else:
+        end = end or -len(alphabet)
+        return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+    import doctest
+    import sys
+
+    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+    if failed:
+        sys.exit(1)
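
Note: within this repo, `braceexpand` is what turns the `SUBSETS` patterns in cultura_x.py into concrete shard paths. A quick check against the smallest subset:

    from speech_lm.utils.braceexpand import braceexpand

    files = list(braceexpand("ja/ja_part_{00000..00159}.parquet"))
    print(len(files), files[0], files[-1])
    # 160 ja/ja_part_00000.parquet ja/ja_part_00159.parquet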