@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
-from transformers import HubertModel
-from torch import nn
+
 import torch
 from encodec.quantization.core_vq import VectorQuantization
+from torch import nn
+from transformers import HubertModel
 
 
 
@@ -177,10 +179,76 @@ class HubertVQ(nn.Module):
         return hidden_states, vq_loss
 
 
-# class HubertVQ
+@dataclass
+class HubertVQOutput:
+    loss: torch.Tensor
+    metrics: dict[str, torch.Tensor]
+
+
+class HubertVQDistill(nn.Module):
+    def __init__(
+        self,
+        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
+        vq_layer: int = -4,  # the layer from which quantized features are extracted
+        codebook_size: int = 1024,
+        trainable_layers_before_vq: int = 2,
+        trainable_layers_after_vq: int = 2,
+        vq_loss_weight: float = 1.0,
+    ):
+        super().__init__()
+
+        self.hubert_vq = HubertVQ(
+            model_name_or_path=model_name_or_path,
+            vq_layer=vq_layer,
+            codebook_size=codebook_size,
+            trainable_layers_before_vq=trainable_layers_before_vq,
+            trainable_layers_after_vq=trainable_layers_after_vq,
+        )
+
+        self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
+        self.vq_loss_weight = vq_loss_weight
+
+        # Freeze the teacher; only the student (hubert_vq) receives gradients.
+        for param in self.hubert_teacher.parameters():
+            param.requires_grad = False
+
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+    ) -> HubertVQOutput:
+        hidden_states, vq_loss = self.hubert_vq(
+            input_values,
+            attention_mask=attention_mask,
+            mask_time_indices=mask_time_indices,
+        )
+
+        # Teacher forward pass, without gradients.
+        with torch.no_grad():
+            teacher_hidden_states = self.hubert_teacher(
+                input_values,
+                attention_mask=attention_mask,
+                mask_time_indices=mask_time_indices,
+            ).last_hidden_state
+
+        distill_loss = torch.nn.functional.mse_loss(
+            hidden_states, teacher_hidden_states
+        )
+
+        loss = distill_loss + vq_loss * self.vq_loss_weight
+
+        metrics = {
+            "distill_loss": distill_loss,
+            "vq_loss": vq_loss,
+        }
+
+        return HubertVQOutput(loss=loss, metrics=metrics)
+
+
 if __name__ == "__main__":
-    from transformers import Wav2Vec2Tokenizer
     from datasets import load_dataset
+    from transformers import Wav2Vec2Tokenizer
 
     processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
     model = HubertVQ()
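
For reviewers, a minimal sketch of how the new `HubertVQDistill` module might be exercised, assuming the constructor defaults in the patch above; the optimizer choice, learning rate, and batch shape are illustrative assumptions, not part of the patch:

```python
import torch

# One distillation training step with HubertVQDistill (sketch only).
# HuBERT checkpoints expect 16 kHz mono audio; batch shape and AdamW
# hyperparameters here are placeholders.
model = HubertVQDistill()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

input_values = torch.randn(2, 16000)  # 2 placeholder 1-second waveforms

output = model(input_values)  # HubertVQOutput(loss=..., metrics=...)
output.loss.backward()
optimizer.step()
optimizer.zero_grad()

# metrics holds scalar tensors for distill_loss and vq_loss.
print({name: value.item() for name, value in output.metrics.items()})
```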