@@ -2,11 +2,12 @@ from dataclasses import dataclass
 from typing import Optional
 
 import torch
-from vector_quantize_pytorch import VectorQuantize
 from torch import nn
+from vector_quantize_pytorch import VectorQuantize
+
 from speech_lm.models.flash_whisper import (
-    FlashWhisperForConditionalGeneration,
     FlashWhisperEncoderLayer,
+    FlashWhisperForConditionalGeneration,
 )
 
 
@@ -15,6 +16,7 @@ class WhisperVQOutput:
     loss: torch.Tensor
     metrics: dict[str, torch.Tensor]
 
+
 class WhisperVQ(nn.Module):
     def __init__(
         self,
@@ -89,7 +91,7 @@ class WhisperVQ(nn.Module):
     ) -> torch.Tensor:
         if attention_mask is not None:
             assert attention_mask.ndim == 2, "Attention mask must be 2D"
-
+
             # Whisper will downsample by 2
             attention_mask = attention_mask[:, ::2]
 
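(Editor's aside, not part of the patch.) The stride-2 slice above exists because Whisper's convolutional front end halves the time axis: the second conv layer has stride 2, so 3000 mel frames become 1500 encoder states, and a padding mask built over mel frames has to be subsampled the same way before it lines up with encoder outputs. A minimal sketch with made-up lengths:

    import torch

    mel_frames = 3000                    # 30 s of audio at a 10 ms hop
    valid = 1200                         # pretend only the first 12 s are real
    mel_mask = torch.zeros(1, mel_frames, dtype=torch.bool)
    mel_mask[:, :valid] = True

    encoder_mask = mel_mask[:, ::2]      # stride-2 conv halves the time axis
    assert encoder_mask.shape[1] == mel_frames // 2  # 1500 encoder states
    assert encoder_mask.sum().item() == valid // 2   # 600 valid encoder states

The next hunk applies the same reasoning a second time when the module additionally mean-pools neighbouring frame pairs.
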
@@ -101,10 +103,14 @@ class WhisperVQ(nn.Module):
         x = hidden_states
         if self.downsample:
             x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
-            attention_mask = attention_mask[:, ::2]
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, ::2]
 
         x = x + self.pre_mlp(self.pre_ln(x))
-        quantized, indices, loss = self.quantizer(x, mask=attention_mask.bool())
+        quantized, indices, loss = self.quantizer(
+            x, mask=attention_mask.bool() if attention_mask is not None else None
+        )
 
         # Fill masked positions with pad embedding
         if attention_mask is not None:
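(Editor's aside, not part of the patch.) The functional change in this hunk is guarding the no-mask case: the mask is only subsampled when it exists, and None is forwarded to the quantizer otherwise, which the patched call relies on vector_quantize_pytorch accepting. A self-contained sketch of that path; dim=1024 and codebook_size=2048 are placeholder sizes, not the model's actual configuration:

    import torch
    from vector_quantize_pytorch import VectorQuantize

    quantizer = VectorQuantize(dim=1024, codebook_size=2048)  # assumed sizes

    x = torch.randn(2, 1500, 1024)  # (batch, encoder frames, dim)
    attention_mask = None           # a mask may legitimately be absent

    # Downsample by averaging neighbouring frame pairs: 1500 -> 750
    x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
    if attention_mask is not None:  # only touch the mask when it exists
        attention_mask = attention_mask[:, ::2]

    quantized, indices, commit_loss = quantizer(
        x, mask=attention_mask.bool() if attention_mask is not None else None
    )
    print(quantized.shape, indices.shape)  # (2, 750, 1024), (2, 750)
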
@@ -121,7 +127,9 @@ class WhisperVQ(nn.Module):
         hidden_states = hidden_states.repeat_interleave(2, dim=1)
 
         # Inject position embeddings
-        positions = torch.arange(0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device)
+        positions = torch.arange(
+            0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device
+        )
         x = hidden_states + self.post_positional_embedding(positions)
 
         # Decode
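(Editor's aside, not part of the patch.) This hunk only reflows the torch.arange call; the unchanged logic around it upsamples the quantized sequence back to the encoder frame rate with repeat_interleave and then adds learned positions. A small sketch with assumed dimensions (1024-dim states, a 1500-slot position table):

    import torch
    from torch import nn

    hidden_states = torch.randn(2, 750, 1024)                  # quantized, pooled frames
    hidden_states = hidden_states.repeat_interleave(2, dim=1)  # 750 -> 1500, each frame doubled

    post_positional_embedding = nn.Embedding(1500, 1024)       # assumed table size
    positions = torch.arange(
        0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device
    )
    x = hidden_states + post_positional_embedding(positions)   # broadcasts over the batch
    print(x.shape)                                             # torch.Size([2, 1500, 1024])
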
@@ -177,23 +185,29 @@ class WhisperVQ(nn.Module):
 
         loss = vq_loss + student_ce_loss + kl_loss
 
-        return WhisperVQOutput(loss=loss, metrics={
-            "vq_loss": vq_loss,
-            "student_ce_loss": student_ce_loss,
-            "teacher_ce_loss": teacher_ce_loss,
-            "kl_loss": kl_loss,
-        })
+        return WhisperVQOutput(
+            loss=loss,
+            metrics={
+                "vq_loss": vq_loss,
+                "student_ce_loss": student_ce_loss,
+                "teacher_ce_loss": teacher_ce_loss,
+                "kl_loss": kl_loss,
+            },
+        )
 
 
 if __name__ == "__main__":
-    from transformers import WhisperProcessor
-    from speech_lm.datasets.whisper_vq import WhisperVQDataset, WhisperVQCollator
     from torch.utils.data import DataLoader
+    from transformers import WhisperProcessor
+
+    from speech_lm.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
 
     processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
     model = WhisperVQ()
 
-    ds = WhisperVQDataset("filelists/whisper-vq.train.test.filelist", "openai/whisper-medium")
+    ds = WhisperVQDataset(
+        "filelists/whisper-vq.train.test.filelist", "openai/whisper-medium"
+    )
     loader = DataLoader(ds, batch_size=8, collate_fn=WhisperVQCollator())
 
     for batch in loader:
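
(Editor's aside, not part of the patch.) The objective assembled in the last hunk is vq_loss + student_ce_loss + kl_loss, with teacher_ce_loss reported in metrics only, as a reference that is not optimized. The reductions, temperatures, and loss weights are not visible in this diff, so the sketch below fills them in with plain defaults purely for illustration:

    import torch
    import torch.nn.functional as F

    def distill_losses(student_logits, teacher_logits, labels, vq_loss):
        # Student is trained against the transcript labels...
        student_ce_loss = F.cross_entropy(
            student_logits.transpose(1, 2), labels, ignore_index=-100
        )
        # ...and pulled toward the teacher's output distribution.
        kl_loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean",
        )
        # Teacher CE mirrors the hunk above: logged as a reference metric,
        # but not part of the summed loss.
        teacher_ce_loss = F.cross_entropy(
            teacher_logits.transpose(1, 2), labels, ignore_index=-100
        )
        loss = vq_loss + student_ce_loss + kl_loss
        return loss, {
            "vq_loss": vq_loss,
            "student_ce_loss": student_ce_loss,
            "teacher_ce_loss": teacher_ce_loss,
            "kl_loss": kl_loss,
        }

    # Toy invocation with a made-up vocabulary size
    student = torch.randn(2, 7, 1000)
    teacher = torch.randn(2, 7, 1000)
    labels = torch.randint(0, 1000, (2, 7))
    loss, metrics = distill_losses(student, teacher, labels, vq_loss=torch.tensor(0.1))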