Lengyue 2 года назад
Родитель
Commit
c06704fb03

+ 2 - 2
speech_lm/configs/llama_pretrain.yaml

@@ -12,7 +12,7 @@ trainer:
   strategy:
     _target_: lightning.fabric.strategies.DDPStrategy
     static_graph: true
-  num_nodes: 4
+  num_nodes: 8
   devices: 8
   precision: bf16-mixed
   loggers:
@@ -37,7 +37,7 @@ tokenizer:
 # This is a 300 billion seen token schedule
 schedule:
   max_length: 1024
-  batch_size: 128  # 128 * 4 = 512
+  batch_size: 64  # 64 * 8 = 512
   micro_batch_size: 8
   max_steps: 100000
   save_interval: 5000

+ 6 - 7
speech_lm/train.py

@@ -99,10 +99,9 @@ def train(
             model.train()
 
             # Accumulate gradients
+            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
+            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
             accumulate_steps += 1
-            is_accumulating = (
-                accumulate_steps < cfg.schedule.gradient_accumulation_steps
-            )
 
             # Train one step
             with fabric.no_backward_sync(model, enabled=is_accumulating):
@@ -111,7 +110,7 @@ def train(
                 metrics = getattr(outputs, "metrics", {})
 
                 # Need to divide loss by accumulation steps
-                fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
+                fabric.backward(loss / gradient_accumulation_steps)
 
                 # Update trackers
                 trackers["loss"].append(float(loss))
@@ -153,14 +152,14 @@ def train(
 
             fabric.log_dict(
                 {
-                    f"train/{k}": sum(v[-accumulate_steps:])
-                    / len(v[-accumulate_steps:])
+                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
+                    / len(v[-gradient_accumulation_steps:])
                     for k, v in trackers.items()
                 },
                 step=global_step,
             )
 
-            accumulate_steps = 0
+            # accumulate_steps = 0
             global_step += 1
 
             if global_step % cfg.schedule.log_interval == 0:

+ 26 - 0
tools/extract_whisper_vq_weights.py

@@ -0,0 +1,26 @@
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+
+
+@click.command()
+@click.argument(
+    "input-file",
+    type=click.Path(exists=True, dir_okay=False, file_okay=True, path_type=Path),
+)
+@click.argument(
+    "output-file",
+    type=click.Path(exists=False, dir_okay=False, file_okay=True, path_type=Path),
+)
+def extract(input_file: Path, output_file: Path):
+    model = torch.load(input_file, map_location="cpu")["model"]
+    state_dict = {k: v for k, v in model.items() if k.startswith("whisper") is False}
+
+    torch.save(state_dict, output_file)
+    logger.info(f"Saved {len(state_dict)} keys to {output_file}")
+
+
+if __name__ == "__main__":
+    extract()

+ 0 - 0
speech_lm/init_model.py → tools/init_llama_model.py