
Add llama finetune command

Lengyue, 2 years ago
commit 423e1c0cfa
2 files changed, 138 insertions(+), 14 deletions(-)
  1. speech_lm/configs/llama_finetune.yaml (+116, -0)
  2. speech_lm/datasets/cultura_x.py (+22, -14)

speech_lm/configs/llama_finetune.yaml (+116, -0)

@@ -0,0 +1,116 @@
+paths:
+  run_dir: results/finetune
+  checkpoint_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+  run:
+    dir: ${paths.run_dir}
+
+trainer:
+  _target_: lightning.fabric.Fabric
+  accelerator: gpu
+  strategy:
+    _target_: lightning.fabric.strategies.DDPStrategy
+    static_graph: true
+  num_nodes: 1
+  devices: 8
+  precision: bf16-mixed
+  loggers:
+    _target_: pytorch_lightning.loggers.TensorBoardLogger
+    save_dir: ${paths.run_dir}
+    name: tensorboard
+    version: null
+
+model:
+  _target_: transformers.AutoModelForCausalLM.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: text-pretrain-10k
+
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: text-pretrain-10k
+
+# ~13B seen-token schedule (100k steps x 128 global batch x 1024 tokens)
+schedule:
+  max_length: 1024
+  batch_size: 16  # per device per optimizer step; 16 * 8 devices = 128 global
+  micro_batch_size: 8
+  max_steps: 100000
+  save_interval: 5000
+  log_interval: 10
+  # "${eval: ...}" requires a custom OmegaConf resolver; see the sketch after this diff
+  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
+  clip_grad_norm: 1.0
+
+train_dataset:
+  _target_: speech_lm.datasets.cultura_x.InterleaveDataset
+  datasets:
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'en'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'zh'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      lang: 'ja'
+    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+      repo: fishaudio/wenet-vq
+      files:
+        - data/train-00000-of-00018-b5a82c6054c6acca.parquet
+        - data/train-00001-of-00018-82467b3e0669c2be.parquet
+        - data/train-00002-of-00018-d50ed8c218a1f183.parquet
+        - data/train-00003-of-00018-15d666053eade100.parquet
+        - data/train-00004-of-00018-01868cb8408e012b.parquet
+        - data/train-00005-of-00018-e766a0b54b1fd08b.parquet
+        - data/train-00006-of-00018-c79fad54ea8a0b8d.parquet
+        - data/train-00007-of-00018-e4155011a7081a1d.parquet
+        - data/train-00008-of-00018-8ba319f5af359d15.parquet
+        - data/train-00009-of-00018-9c9e984a6565b2c3.parquet
+        - data/train-00010-of-00018-7af80a80e5aa1e54.parquet
+        - data/train-00011-of-00018-2ab91221787a84a3.parquet
+        - data/train-00012-of-00018-4d477812eea5d298.parquet
+        - data/train-00013-of-00018-faf87b68b1ab4a15.parquet
+        - data/train-00014-of-00018-7f6bbd9bcb4cbb55.parquet
+        - data/train-00015-of-00018-d630fe4a488b9f51.parquet
+        - data/train-00016-of-00018-969a4d5dc04d2764.parquet
+        - data/train-00017-of-00018-bbfd09175809d1fe.parquet
+  probabilities: [0.2, 0.2, 0.2, 0.4]
+  seed: 42
+
+train_dataloader:
+  _target_: torch.utils.data.DataLoader
+  dataset: ${train_dataset}
+  batch_size: ${schedule.micro_batch_size}
+  num_workers: 8
+  collate_fn:
+    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
+
+valid_dataloader:
+  _target_: torch.utils.data.DataLoader
+  dataset:
+    _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    repo: fishaudio/wenet-vq
+    files:
+      - data/test-00000-of-00001-685250c116f5d321.parquet
+  batch_size: ${schedule.micro_batch_size}
+  num_workers: 1
+  collate_fn:
+    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    tokenizer: ${tokenizer}
+    max_length: ${schedule.max_length}
+
+optimizer:
+  _target_: torch.optim.AdamW
+  lr: 1e-4
+  weight_decay: 0.1
+  betas: [0.9, 0.95]
+  eps: 1e-5
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.LambdaLR
+  lr_lambda:
+    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _partial_: true
+    num_warmup_steps: 2000
+    num_training_steps: ${schedule.max_steps}
+    final_lr_ratio: 0.1
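
A note on consuming this config: the "${eval: ...}" interpolation in schedule.gradient_accumulation_steps is not a built-in OmegaConf resolver, so the training entrypoint has to register one before composing the config. Below is a minimal sketch of how the file might be wired up; the entrypoint itself is not part of this commit, so the function body and the resolver registration are assumptions rather than the project's actual code.

import hydra
from omegaconf import DictConfig, OmegaConf

# Assumption: an "eval" resolver is registered so that
# "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}" resolves to 2.
OmegaConf.register_new_resolver("eval", eval)


@hydra.main(version_base=None, config_path="speech_lm/configs", config_name="llama_finetune")
def main(cfg: DictConfig) -> None:
    # Everything below is built from the _target_ entries in the YAML above.
    fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()

    model = hydra.utils.instantiate(cfg.model)
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)

    model, optimizer = fabric.setup(model, optimizer)
    train_loader = fabric.setup_dataloaders(hydra.utils.instantiate(cfg.train_dataloader))
    # ... training loop elided: forward, loss, fabric.backward(loss), gradient
    # clipping to schedule.clip_grad_norm, then optimizer and scheduler steps.


if __name__ == "__main__":
    main()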

speech_lm/datasets/cultura_x.py (+22, -14)

@@ -2,6 +2,7 @@ import random
 from dataclasses import dataclass
 from logging import getLogger
 from random import Random
+from typing import Optional
 
 import numpy as np
 import pandas as pd
@@ -23,15 +24,28 @@ log = getLogger(__name__)
 
 
 class CulturaXDataset(IterableDataset):
-    def __init__(self, lang: str, seed: int = 42, parquet_batch_size: int = 10000):
+    def __init__(
+        self,
+        lang: Optional[str] = None,
+        seed: int = 42,
+        parquet_batch_size: int = 10000,
+        repo: str = "uonlp/CulturaX",
+        files: Optional[list[str]] = None,
+    ):
         super().__init__()
 
         self.lang = lang
         self.seed = seed
         self.parquet_batch_size = parquet_batch_size
+        self.repo = repo
+
+        if self.lang is not None:
+            files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
+        else:
+            assert files is not None, "either `lang` or `files` must be provided"
+            files = list(files)
 
         # Get sharded files
-        self.files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
+        self.files = files
         Random(seed).shuffle(self.files)
 
     def get_data_splits(self, files):
@@ -71,7 +85,7 @@ class CulturaXDataset(IterableDataset):
                 log.exception(f"Failed to parse {filename}: {e}")
 
     def parse_data(self, filename: str):
-        url = f"https://huggingface.co/datasets/uonlp/CulturaX/resolve/main/{filename}"
+        url = f"https://huggingface.co/datasets/{self.repo}/resolve/main/{filename}"
 
         with xopen(url, mode="rb") as stream:
             parquet_file = pq.ParquetFile(stream)
@@ -91,16 +105,7 @@ class CulutreXCollator:
     max_length: int = 512
 
     def __call__(self, examples):
-        texts = []
-
-        for example in examples:
-            text = example["text"]
-
-            if len(text) <= self.max_length:
-                texts.append(text)
-            else:
-                start = random.randint(0, len(text) - self.max_length)
-                texts.append(text[start : start + self.max_length])
+        texts = [i["text"] for i in examples]
 
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -152,9 +157,12 @@ class InterleaveDataset(IterableDataset):
 if __name__ == "__main__":
     from torch.utils.data import DataLoader
 
+    from speech_lm.datasets.wenet_vq import WenetVQDataset
+
     dataset_en = CulturaXDataset("en")
     dataset_ja = CulturaXDataset("ja")
-    dataset = InterleaveDataset([dataset_en, dataset_ja], [0.5, 0.5])
+    dataset_wenet = WenetVQDataset()
+    dataset = InterleaveDataset([dataset_en, dataset_wenet], [0.5, 0.5])
     collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
 
     for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
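
With the constructor change above, CulturaXDataset now streams either a CulturaX language subset (via lang) or arbitrary parquet shards from another Hub repo (via repo and files), which is how the config points it at fishaudio/wenet-vq; the collator meanwhile passes texts through whole instead of taking a random character-level crop. A minimal usage sketch of the two modes, reusing only names and files that appear in this commit:

from speech_lm.datasets.cultura_x import CulturaXDataset

# Language mode: the shard list is brace-expanded from the SUBSETS table.
dataset_en = CulturaXDataset(lang="en")

# Explicit-files mode: stream parquet shards from another Hub repo.
dataset_vq = CulturaXDataset(
    repo="fishaudio/wenet-vq",
    files=["data/test-00000-of-00001-685250c116f5d321.parquet"],
)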