
Add llm exp code

Lengyue 2 years ago
parent
commit
db347ca973
7 changed files with 187 additions and 22 deletions
  1. dockerfile (+34, -0)
  2. requirements.txt (+0, -2)
  3. speech_lm/configs/pretrain.yaml (+1, -1)
  4. speech_lm/dataset.py (+1, -1)
  5. speech_lm/train.py (+19, -18)
  6. train.py (+99, -0)
  7. train.sh (+33, -0)

+ 34 - 0
dockerfile

@@ -0,0 +1,34 @@
+FROM nvcr.io/nvidia/pytorch:23.09-py3
+
+# Install system dependencies
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && apt-get install -y git curl build-essential ffmpeg libsm6 libxext6 libjpeg-dev \
+    zlib1g-dev aria2 zsh openssh-server sudo python3.10-venv && \
+    apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Install s5cmd
+RUN curl -L https://github.com/peak/s5cmd/releases/download/v2.1.0-beta.1/s5cmd_2.1.0-beta.1_Linux-64bit.tar.gz | tar xvz -C /tmp && \
+    mv /tmp/s5cmd /usr/local/bin/s5cmd && s5cmd --help
+
+# Install code server and zsh
+RUN wget -c https://github.com/coder/code-server/releases/download/v4.5.1/code-server_4.5.1_amd64.deb && \
+    dpkg -i ./code-server_4.5.1_amd64.deb && \
+    code-server --install-extension ms-python.python && \
+    rm ./code-server_4.5.1_amd64.deb && \
+    sh -c "$(curl https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)" "" --unattended
+
+# Set zsh as default shell
+RUN chsh -s /usr/bin/zsh
+ENV SHELL=/usr/bin/zsh
+
+# Setup flash-attn
+RUN pip3 install --upgrade pip && \
+    pip3 install ninja packaging && \
+    MAX_JOBS=4 pip3 install flash-attn --no-build-isolation
+
+# Project Env
+WORKDIR /exp
+COPY requirements.txt .
+RUN pip3 install -r requirements.txt
+
+CMD /bin/zsh
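
A quick, optional sanity check (not part of this commit) that the flash-attn build step above succeeded, run from Python inside the built image:

# Optional check inside the container; confirms the flash-attn wheel compiled with MAX_JOBS=4
# imports cleanly against the NGC PyTorch in the base image.
import torch
import flash_attn

print(torch.__version__, torch.cuda.is_available())
print(flash_attn.__version__)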

+ 0 - 2
requirements.txt

@@ -6,5 +6,3 @@ lightning>=2.0.9.post0
 hydra-core>=1.3.2
 pyrootutils>=1.0.4
 tensorboard>=2.14.1
-ninja
-packaging

+ 1 - 1
speech_lm/configs/pretrain.yaml

@@ -32,7 +32,7 @@ tokenizer:
 # 3e12 / 1024 / 512 / 8 = 715255
 schedule:
   max_length: 1024
-  batch_size: 512
+  batch_size: 16
   max_steps: 715255
   save_every: 2000
 

+ 1 - 1
speech_lm/dataset.py

@@ -25,7 +25,7 @@ def encode(examples, tokenizer, max_length=512):
     )
     data["labels"] = data["input_ids"].clone()
     data["labels"][data["attention_mask"] == 0] = -100
-    print(data["input_ids"].shape)
+
     return data
 
 

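For context on the unchanged masking line above (data["labels"][data["attention_mask"] == 0] = -100): -100 is the default ignore_index of PyTorch's cross-entropy loss, so padded positions are excluded from the loss. A minimal illustration, not part of the commit:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 10)           # (batch, seq_len, vocab)
labels = torch.tensor([[4, 7, -100]])    # last position is padding
loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))  # ignore_index defaults to -100,
                                                              # so padding contributes nothing
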
+ 19 - 18
speech_lm/train.py

@@ -8,6 +8,7 @@ from lightning.fabric import Fabric
 from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
 from transformers.utils import is_flash_attn_available
+from transformers import LlamaForCausalLM
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -24,11 +25,11 @@ log = RankedLogger(__name__, rank_zero_only=True)
 
 
 def train(
-    model,
-    optimizer,
-    scheduler,
-    dataloader,
-    global_step,
+    model: LlamaForCausalLM,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    dataloader: torch.utils.data.DataLoader,
+    global_step: int,
     fabric: Fabric,
     cfg: DictConfig,
 ):
@@ -37,16 +38,17 @@ def train(
 
     while global_step < cfg.schedule.max_steps:
         for batch in dataloader:
-            print(batch)
-            # batch = fabric.setup_batch(batch)
-            # loss = model(**batch).loss
-            # loss.backward()
-            # optimizer.step()
-            # scheduler.step()
-            # optimizer.zero_grad()
-            # global_step += 1
-            # bar.update(1)
-            # bar.set_postfix({"loss": loss.item()})
+            # Train loop
+            optimizer.zero_grad()
+            loss = model(**batch).loss
+            fabric.backward(loss)
+            optimizer.step()
+            scheduler.step()
+
+            fabric.log_dict({
+                "train/loss": loss,
+                "train/lr": optimizer.param_groups[0]["lr"],
+            }, step=global_step)
 
             global_step += 1
             bar.update(1)
@@ -71,15 +73,14 @@ def main(cfg: DictConfig):
     log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
 
     if is_flash_attn_available() is False:
-        raise RuntimeError(
-            "Flash attention is not available, training will be aborted."
-        )
+        log.warning("Flash attention is not available, using default attention")
 
     fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
     fabric.launch()
     log.info(f"Fabric: {fabric}")
 
     model = hydra.utils.instantiate(cfg.model)
+    log.info(f"Model: {repr(model)}")
 
     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
     freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
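
The reworked loop above relies on Fabric having already wrapped the model, optimizer, and dataloader, and on a logger backing fabric.log_dict. That wiring is outside this diff, so the following is only a hedged sketch of what it assumes; dummy objects stand in for the hydra-instantiated model, optimizer, and dataloader so the snippet runs standalone.

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger

fabric = Fabric(accelerator="auto", devices=1, loggers=TensorBoardLogger("logs"))
fabric.launch()

model = torch.nn.Linear(8, 8)                      # stand-in for the LlamaForCausalLM
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)  # wraps for the chosen strategy / precision
dataloader = fabric.setup_dataloaders(
    DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)
)                                                  # adds distributed sampler + device moves
# With this in place, fabric.backward(loss) and fabric.log_dict(...) in train() work as written.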

+ 99 - 0
train.py

@@ -0,0 +1,99 @@
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Optional
+from speech_lm.dataset import build_dataset
+from datasets import load_dataset, load_from_disk
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    HfArgumentParser,
+    Trainer,
+)
+from transformers import TrainingArguments as _TrainingArguments
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="fishaudio/speech-lm-300m")
+    model_revision: Optional[str] = field(default="main")
+
+
+@dataclass
+class DataArguments:
+    data_path: Optional[str] = field(
+        default=None, metadata={"help": "Dataset path for load_from_disk / load_dataset."}
+    )  # referenced below as data_args.data_path
+
+@dataclass
+class TrainingArguments(_TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=512,
+        metadata={
+            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+        },
+    )
+    use_lora: bool = field(default=False)
+
+
+def train():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+        revision=model_args.model_revision,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        use_fast=False,
+        trust_remote_code=True,
+        model_max_length=training_args.model_max_length,
+        cache_dir=training_args.cache_dir,
+        revision=model_args.model_revision,
+    )
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    if training_args.use_lora:
+        from peft import LoraConfig, TaskType, get_peft_model
+
+        peft_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["W_pack"],
+            inference_mode=False,
+            r=16,
+            lora_alpha=64,
+            lora_dropout=0.1,
+        )
+        model.enable_input_require_grads()
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    try:
+        dataset = load_from_disk(data_args.data_path)
+        if "train" in dataset:
+            dataset = dataset["train"]
+    except Exception:
+        # fall back to loading the dataset from the Hugging Face Hub
+        dataset = load_dataset(data_args.data_path, split="train")
+
+    dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
+    dataset = dataset.train_test_split(test_size=1000, seed=42)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset["train"],
+        eval_dataset=dataset["test"],
+        tokenizer=tokenizer,
+        data_collator=DataCollatorWithPadding(tokenizer),
+    )
+    trainer.train()
+    trainer.save_state()
+    trainer.save_model(output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
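
One gap worth flagging: dataset_transform is passed to set_transform above but is neither defined nor imported in this file (only build_dataset is imported), so presumably it lives next to the encode() helper in speech_lm/dataset.py. A hypothetical sketch of what it might look like, modeled on encode():

def dataset_transform(examples, tokenizer, max_length=512):
    data = tokenizer(
        examples["text"],              # "text" column name is an assumption
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    data["labels"] = data["input_ids"].clone()
    data["labels"][data["attention_mask"] == 0] = -100  # ignore padding in the loss
    return data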

+ 33 - 0
train.sh

@@ -0,0 +1,33 @@
+# export NCCL_P2P_DISABLE=1
+
+# hostfile=""
+# deepspeed --hostfile=$hostfile train.py \
+#     --deepspeed tools/tts/ds_config.json \
+#     --report_to "tensorboard" \
+#     --model_name_or_path "fishaudio/speech-lm-300m" \
+#     --model_revision "init" \
+#     --output_dir "results" \
+#     --model_max_length 4096 \
+#     --max_steps 500000 \
+#     --per_device_train_batch_size 32 \
+#     --gradient_accumulation_steps 1 \
+#     --save_strategy steps \
+#     --save_steps 10000 \
+#     --evaluation_strategy steps \
+#     --eval_steps 10000 \
+#     --learning_rate 1e-3 \
+#     --lr_scheduler_type cosine \
+#     --adam_beta1 0.9 \
+#     --adam_beta2 0.98 \
+#     --adam_epsilon 1e-8 \
+#     --max_grad_norm 1.0 \
+#     --weight_decay 1e-4 \
+#     --warmup_steps 10000 \
+#     --logging_steps 1 \
+#     --gradient_checkpointing True \
+#     --remove_unused_columns False \
+#     --use_lora False \
+#     --bf16 True \
+#     --tf32 True
+
+accelerate launch --config_file accelerate-config.yaml train_unconditional.py --config configs/svc.yaml --project test-simple-ne