Browse Source

Update base training code

Lengyue 2 years ago
parent
commit
9adf687fb1

+ 13 - 7
preparing_data/wenet_clean/compress_tar.py

@@ -1,9 +1,10 @@
-import tarfile
-from pathlib import Path
-from tqdm import tqdm
 import io
 import random
+import tarfile
 from multiprocessing import Process
+from pathlib import Path
+
+from tqdm import tqdm
 
 
 def chunked_tarring(rank, file_list, base_folder, output_folder, chunk_size=1024**3):
@@ -25,7 +26,9 @@ def chunked_tarring(rank, file_list, base_folder, output_folder, chunk_size=1024
 
             # write the buffer to disk
             buffer.seek(0)
-            with open(output_folder / f"chunk-{rank:03d}-{chunk_count:04d}.tar", "wb") as f:
+            with open(
+                output_folder / f"chunk-{rank:03d}-{chunk_count:04d}.tar", "wb"
+            ) as f:
                 f.write(buffer.read())
 
             chunk_count += 1
@@ -42,14 +45,14 @@ def chunked_tarring(rank, file_list, base_folder, output_folder, chunk_size=1024
 
         if saved_count % 1000 == 0:
             print(f"Rank {rank}: {saved_count}/{len(file_list)}")
-        
+
         saved_count += 1
 
     tar.close()
     buffer.seek(0)
     with open(output_folder / f"chunk-{rank:03d}-{chunk_count:04d}.tar", "wb") as f:
         f.write(buffer.read())
-    
+
     print(f"Rank {rank}: {saved_count}/{len(file_list)}")
 
 
@@ -72,7 +75,10 @@ if __name__ == "__main__":
         if i == num_workers - 1:
             end = len(file_list)
 
-        p = Process(target=chunked_tarring, args=(i, file_list[start:end], base_folder, output_folder))
+        p = Process(
+            target=chunked_tarring,
+            args=(i, file_list[start:end], base_folder, output_folder),
+        )
         p.start()
         processes.append(p)
 

+ 60 - 0
speech_lm/configs/pretrain.yaml

@@ -1,6 +1,66 @@
 paths:
   run_dir: results/pretrain
+  checkpoint_dir: ${paths.run_dir}/checkpoints
 
 hydra:
   run:
     dir: ${paths.run_dir}
+
+trainer:
+  _target_: lightning.fabric.Fabric
+  accelerator: gpu
+  strategy: ddp
+  devices: auto
+  precision: bf16-mixed
+  loggers:
+    - _target_: pytorch_lightning.loggers.TensorBoardLogger
+      save_dir: ${paths.run_dir}
+      name: tensorboard
+      version: null
+
+model:
+  _target_: transformers.AutoModelForCausalLM.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: init
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: init
+
+# Say we want a 3 trillion seen token schedule
+# 3e12 / 1024 / 512 / 8 = 715255
+schedule:
+  max_length: 1024
+  batch_size: 512
+  max_steps: 715255
+  save_every: 2000
+
+dataloader:
+  _target_: torch.utils.data.DataLoader
+  dataset: 
+    _target_: speech_lm.dataset.build_dataset
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
+  batch_size: ${schedule.batch_size}
+  num_workers: 4
+  collate_fn:
+    _target_: transformers.DataCollatorWithPadding
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
+
+optimizer:
+  _target_: torch.optim.AdamW
+  lr: 3e-4
+  weight_decay: 0.1
+  betas: [0.9, 0.95]
+  eps: 1e-5
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.LambdaLR
+  lr_lambda:
+    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _partial_: true
+    num_warmup_steps: 2000
+    num_training_steps: ${schedule.max_steps}
+    final_lr_ratio: 0.1

+ 20 - 22
speech_lm/dataset.py

@@ -1,39 +1,35 @@
 import random
-from transformers import AutoTokenizer
-from datasets import load_dataset, interleave_datasets, IterableDataset
-from functools import lru_cache
-from torch.utils.data import DataLoader
+from functools import partial
+
+from datasets import IterableDataset, interleave_datasets, load_dataset
 from datasets.distributed import split_dataset_by_node
 from torch.distributed import get_rank, get_world_size, is_initialized
 
 
-@lru_cache(maxsize=1)
-def get_tokenizer():
-    return AutoTokenizer.from_pretrained("fishaudio/speech-lm-300m", revision="init")
-
-def encode(examples):
+def encode(examples, tokenizer, max_length=512):
     # Random choice a 512 token window for each example
     texts = []
     for text in examples["text"]:
-        if len(text) <= 512:
+        if len(text) <= max_length:
             texts.append(text)
         else:
-            start = random.randint(0, len(text) - 512)
-            texts.append(text[start : start + 512])
-    
-    data = get_tokenizer()(
-        texts, 
-        truncation=True, 
+            start = random.randint(0, len(text) - max_length)
+            texts.append(text[start : start + max_length])
+
+    data = tokenizer(
+        texts,
+        truncation=True,
         padding="max_length",
-        max_length=512,
+        max_length=max_length,
+        return_tensors="pt",
     )
-    data["labels"] = data["input_ids"].copy()
+    data["labels"] = data["input_ids"].clone()
     data["labels"][data["attention_mask"] == 0] = -100
-
+    print(data["input_ids"].shape)
     return data
 
 
-def build_dataset():
+def build_dataset(tokenizer, max_length=512):
     en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
     ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
     zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)
@@ -47,13 +43,15 @@ def build_dataset():
         multilingual_dataset = split_dataset_by_node(
             multilingual_dataset,
             rank=get_rank(),
-            num_replicas=get_world_size(),
+            world_size=get_world_size(),
         )
 
     multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
 
     multilingual_dataset = multilingual_dataset.map(
-        encode, batched=True, remove_columns=multilingual_dataset.column_names
+        partial(encode, tokenizer=tokenizer, max_length=max_length),
+        batched=True,
+        remove_columns=multilingual_dataset.column_names,
     )
 
     return multilingual_dataset

+ 1 - 1
speech_lm/init_model.py

@@ -1,4 +1,4 @@
-from transformers import LlamaModel, LlamaConfig, AutoTokenizer
+from transformers import AutoTokenizer, LlamaConfig, LlamaModel
 
 # reuse the tokenizer from the llama
 model_type = "meta-llama/Llama-2-7b-hf"

+ 51 - 0
speech_lm/logger.py

@@ -0,0 +1,51 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = False,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)

+ 20 - 0
speech_lm/scheduler.py

@@ -0,0 +1,20 @@
+import math
+
+def get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    final_lr_ratio: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+
+    return max(
+        final_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )

+ 116 - 4
speech_lm/train.py

@@ -1,8 +1,13 @@
+import logging
+from pathlib import Path
+
+import hydra
+import pyrootutils
 import torch
 from lightning.fabric import Fabric
-import hydra
 from omegaconf import DictConfig, OmegaConf
-import pyrootutils
+from tqdm import tqdm
+from transformers.utils import is_flash_attn_available
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -13,11 +18,118 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 OmegaConf.register_new_resolver("eval", eval)
 
 # flake8: noqa: E402
-from speech_lm.dataset import build_dataset
+from speech_lm.logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def train(
+    model,
+    optimizer,
+    scheduler,
+    dataloader,
+    global_step,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
+    bar.update(global_step)
+
+    while global_step < cfg.schedule.max_steps:
+        for batch in dataloader:
+            print(batch)
+            # batch = fabric.setup_batch(batch)
+            # loss = model(**batch).loss
+            # loss.backward()
+            # optimizer.step()
+            # scheduler.step()
+            # optimizer.zero_grad()
+            # global_step += 1
+            # bar.update(1)
+            # bar.set_postfix({"loss": loss.item()})
+
+            global_step += 1
+            bar.update(1)
+
+            if global_step % cfg.schedule.save_steps == 0:
+                fabric.save(
+                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
+                    {
+                        "model": model,
+                        "optimizer": optimizer,
+                        "scheduler": scheduler,
+                        "global_step": global_step,
+                    },
+                )
+
+            if global_step >= cfg.schedule.max_steps:
+                break
+
 
 @hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
 def main(cfg: DictConfig):
-    print(cfg)
+    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
+
+    if is_flash_attn_available() is False:
+        raise RuntimeError(
+            "Flash attention is not available, training will be aborted."
+        )
+
+    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
+    fabric.launch()
+    log.info(f"Fabric: {fabric}")
+
+    model = hydra.utils.instantiate(cfg.model)
+
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
+    log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
+
+    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
+    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
+    log.info(f"Optimizer: {optimizer}")
+    log.info(f"Scheduler: {scheduler}")
+
+    # Build state
+    global_step = 0
+
+    # Restore training from checkpoint
+    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_path = checkpoint_dir / "last.ckpt"
+    if checkpoint_path.exists():
+        log.info(f"Restoring checkpoint from {checkpoint_path}")
+        remainder = fabric.load(
+            checkpoint_path,
+            {
+                "model": model,
+                "optimizer": optimizer,
+                "scheduler": scheduler,
+            },
+        )
+        global_step = remainder["global_step"]
+        log.info(f"Restored global step: {global_step}")
+
+    log.info(f"Setup fabric model & dataset")
+    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
+
+    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
+    log.info(f"Dataloader: {train_dataloader}")
+
+    train_dataloader = fabric.setup_dataloaders(train_dataloader)
+    log.info(f"Begin training")
+
+    train(
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        dataloader=train_dataloader,
+        global_step=global_step,
+        fabric=fabric,
+        cfg=cfg,
+    )
+
 
 if __name__ == "__main__":
     main()