Commit 3297b2dc2c — author: Lengyue, 2 years ago (parent commit linked)

+ 2 - 0
requirements.txt

@@ -6,3 +6,5 @@ lightning>=2.0.9.post0
 hydra-core>=1.3.2
 pyrootutils>=1.0.4
 tensorboard>=2.14.1
+librosa
+encodec

+ 63 - 0
speech_lm/configs/hubert_vq.yaml

@@ -0,0 +1,63 @@
+paths:
+  run_dir: results/pretrain
+  checkpoint_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+  run:
+    dir: ${paths.run_dir}
+
+trainer:
+  _target_: lightning.fabric.Fabric
+  accelerator: gpu
+  strategy: ddp
+  devices: auto
+  precision: bf16-mixed
+  loggers:
+    _target_: pytorch_lightning.loggers.TensorBoardLogger
+    save_dir: ${paths.run_dir}
+    name: tensorboard
+    version: null
+
+model:
+  _target_: speech_lm.models.hubert_vq.HubertVQDistill
+  model_name_or_path: facebook/hubert-large-ls960-ft
+  vq_layer: -4
+  codebook_size: 4096
+  trainable_layers_before_vq: 2
+  trainable_layers_after_vq: 2
+  vq_loss_weight: 1.0
+
+schedule:
+  batch_size: 128
+  micro_batch_size: 128
+  max_steps: 100000
+  save_interval: 2000
+  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
+  clip_grad_norm: 1.0
+
+dataset:
+  _target_: speech_lm.datasets.hubert_vq.HubertVQDataset
+  filelist: libritts-r.filelist
+
+dataloader:
+  _target_: torch.utils.data.DataLoader
+  dataset: ${dataset}
+  batch_size: ${schedule.micro_batch_size}
+  num_workers: 4
+  collate_fn:
+    _target_: speech_lm.datasets.hubert_vq.HubertVQCollator
+
+optimizer:
+  _target_: torch.optim.AdamW
+  lr: 3e-4
+  weight_decay: 0.1
+  betas: [0.9, 0.95]
+  eps: 1e-5
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.LambdaLR
+  lr_lambda:
+    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _partial_: true
+    num_warmup_steps: 2000
+    num_training_steps: ${schedule.max_steps}

+ 25 - 20
speech_lm/datasets/hubert_vq.py

@@ -1,11 +1,12 @@
-import librosa
-from torch.utils.data import Dataset
 from pathlib import Path
+
+import librosa
 import torch
+from torch.utils.data import Dataset
 
 
 class HubertVQDataset(Dataset):
-    def __init__(self, filelist):
+    def __init__(self, filelist: str):
         super().__init__()
 
         self.files = Path(filelist).read_text().splitlines()
@@ -20,35 +21,39 @@ class HubertVQDataset(Dataset):
         return wav
 
 
-def collate_fn(batch):
-    # -> {"input_values": ..., "attention_mask": ...}
-    max_length = max([len(x) for x in batch])
+class HubertVQCollator:
+    @staticmethod
+    def __call__(batch):
+        # -> {"input_values": ..., "attention_mask": ...}
+        max_length = max([len(x) for x in batch])
 
-    input_values = []
-    attention_mask = []
+        input_values = []
+        attention_mask = []
 
-    for x in batch:
-        x_length = len(x)
-        x = torch.nn.functional.pad(x, (0, max_length - x_length))
-        mask = torch.ones_like(x)
-        mask[x_length:] = 0
+        for x in batch:
+            x_length = len(x)
+            x = torch.nn.functional.pad(x, (0, max_length - x_length))
+            mask = torch.ones_like(x)
+            mask[x_length:] = 0
 
-        input_values.append(x)
-        attention_mask.append(mask)
+            input_values.append(x)
+            attention_mask.append(mask)
 
-    input_values = torch.stack(input_values)
-    attention_mask = torch.stack(attention_mask)
+        input_values = torch.stack(input_values)
+        attention_mask = torch.stack(attention_mask)
 
-    return {"input_values": input_values, "attention_mask": attention_mask}
+        return {"input_values": input_values, "attention_mask": attention_mask}
 
 
 if __name__ == "__main__":
+    import soundfile as sf
     from torch.utils.data import DataLoader
     from transformers import HubertForCTC, Wav2Vec2Processor
-    import soundfile as sf
 
     dataset = HubertVQDataset("libritts-r.filelist")
-    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
+    dataloader = DataLoader(
+        dataset, batch_size=16, shuffle=True, collate_fn=HubertVQCollator()
+    )
     hubert = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
     processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
     hubert.eval()

+ 6 - 2
speech_lm/logger.py

@@ -24,7 +24,9 @@ class RankedLogger(logging.LoggerAdapter):
         super().__init__(logger=logger, extra=extra)
         self.rank_zero_only = rank_zero_only
 
-    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+    def log(
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+    ) -> None:
         """Delegate a log call to the underlying logger, after prefixing its message with the rank
         of the process it's being logged from. If `'rank'` is provided, then the log will only
         occur on that rank/process.
@@ -39,7 +41,9 @@ class RankedLogger(logging.LoggerAdapter):
             msg, kwargs = self.process(msg, kwargs)
             current_rank = getattr(rank_zero_only, "rank", None)
             if current_rank is None:
-                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+                raise RuntimeError(
+                    "The `rank_zero_only.rank` needs to be set before use"
+                )
             msg = rank_prefixed_message(msg, current_rank)
             if self.rank_zero_only:
                 if current_rank == 0:

+ 72 - 4
speech_lm/models/hubert_vq.py

@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
-from transformers import HubertModel
-from torch import nn
+
 import torch
 from encodec.quantization.core_vq import VectorQuantization
+from torch import nn
+from transformers import HubertModel
 
 
 class HubertVQ(nn.Module):
@@ -177,10 +179,76 @@ class HubertVQ(nn.Module):
         return hidden_states, vq_loss
 
 
-# class HubertVQ
+@dataclass
+class HubertVQOutput:
+    loss: torch.Tensor
+    metrics: dict[str, torch.Tensor]
+
+
+class HubertVQDistill(nn.Module):
+    def __init__(
+        self,
+        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
+        vq_layer: int = -4,  # the layer to extract the quantized features
+        codebook_size: int = 1024,
+        trainable_layers_before_vq: int = 2,
+        trainable_layers_after_vq: int = 2,
+        vq_loss_weight: float = 1.0,
+    ):
+        super().__init__()
+
+        self.hubert_vq = HubertVQ(
+            model_name_or_path=model_name_or_path,
+            vq_layer=vq_layer,
+            codebook_size=codebook_size,
+            trainable_layers_before_vq=trainable_layers_before_vq,
+            trainable_layers_after_vq=trainable_layers_after_vq,
+        )
+
+        self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
+        self.vq_loss_weight = vq_loss_weight
+
+        # Freeze teacher
+        for param in self.hubert_teacher.parameters():
+            param.requires_grad = False
+
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+    ) -> HubertVQOutput:
+        hidden_states, vq_loss = self.hubert_vq(
+            input_values,
+            attention_mask=attention_mask,
+            mask_time_indices=mask_time_indices,
+        )
+
+        # Teacher
+        with torch.no_grad():
+            teacher_hidden_states = self.hubert_teacher(
+                input_values,
+                attention_mask=attention_mask,
+                mask_time_indices=mask_time_indices,
+            ).last_hidden_state
+
+        distill_loss = torch.nn.functional.mse_loss(
+            hidden_states, teacher_hidden_states
+        )
+
+        loss = distill_loss + vq_loss * self.vq_loss_weight
+
+        metrics = {
+            "distill_loss": distill_loss,
+            "vq_loss": vq_loss,
+        }
+
+        return HubertVQOutput(loss=loss, metrics=metrics)
+
+
 if __name__ == "__main__":
-    from transformers import Wav2Vec2Tokenizer
     from datasets import load_dataset
+    from transformers import Wav2Vec2Tokenizer
 
     processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
     model = HubertVQ()

+ 3 - 1
speech_lm/scheduler.py

@@ -1,5 +1,6 @@
 import math
 
+
 def get_cosine_schedule_with_warmup_lr_lambda(
     current_step: int,
     *,
@@ -16,5 +17,6 @@ def get_cosine_schedule_with_warmup_lr_lambda(
     )
 
     return max(
-        final_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+        final_lr_ratio,
+        0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
     )

+ 5 - 2
speech_lm/train.py

@@ -6,8 +6,8 @@ import torch
 from lightning.fabric import Fabric
 from omegaconf import DictConfig, OmegaConf
 from tqdm import tqdm
-from transformers.utils import is_flash_attn_available
 from transformers import LlamaForCausalLM
+from transformers.utils import is_flash_attn_available
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -47,7 +47,9 @@ def train(
 
             # Train one step
             with fabric.no_backward_sync(model, enabled=is_accumulating):
-                loss = model(**batch).loss
+                outputs = model(**batch)
+                loss = outputs.loss
+                metrics = getattr(outputs, "metrics", {})
                 fabric.backward(loss)
 
             if is_accumulating:
@@ -68,6 +70,7 @@ def train(
                     "train/loss": loss,
                     "train/lr": optimizer.param_groups[0]["lr"],
                     "train/grad_norm": grad_norm,
+                    **{f"train/{k}": v for k, v in metrics.items()},
                 },
                 step=global_step,
             )