from dataclasses import dataclass
from typing import Optional

import torch
from encodec.quantization.core_vq import VectorQuantization
from torch import nn
from transformers import HubertModel


class HubertVQ(nn.Module):
    """HuBERT with a vector-quantization bottleneck inserted at ``vq_layer``.

    All HuBERT parameters are frozen except a small window of encoder layers
    around the bottleneck, which are fine-tuned together with the quantizer.
    """

    def __init__(
        self,
        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
        vq_layer: int = -4,  # the layer to extract the quantized features
        codebook_size: int = 1024,
        trainable_layers_before_vq: int = 2,
        trainable_layers_after_vq: int = 2,
    ):
        super().__init__()
        self.hubert = HubertModel.from_pretrained(model_name_or_path)
        self.vq_layer = (
            (self.hubert.config.num_hidden_layers + vq_layer)
            if vq_layer < 0
            else vq_layer
        )
        self.trainable_layers_before_vq = trainable_layers_before_vq
        self.trainable_layers_after_vq = trainable_layers_after_vq
        assert (
            self.vq_layer >= trainable_layers_before_vq
            and self.vq_layer
            < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
        ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"

        # Freeze the entire HuBERT backbone
        for param in self.hubert.parameters():
            param.requires_grad = False
        # Unfreeze layers between vq_layer - trainable_layers_before_vq and vq_layer + trainable_layers_after_vq
        for param in self.hubert.encoder.layers[
            self.vq_layer
            - trainable_layers_before_vq : self.vq_layer
            + trainable_layers_after_vq
        ].parameters():
            param.requires_grad = True

        # Quantization bottleneck
        self.quantizer_ln = nn.LayerNorm(self.hubert.config.hidden_size)
        self.quantizer = VectorQuantization(
            codebook_size=codebook_size,
            dim=self.hubert.config.hidden_size,
            kmeans_init=False,
        )

    @torch.no_grad()
    def _get_attention_mask(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # compute reduced attention_mask corresponding to feature vectors
        attention_mask = self.hubert._get_feature_vector_attention_mask(
            hidden_states.shape[1], attention_mask
        )
        # make sure padded tokens are not attended to
        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
            1, 1, hidden_states.shape[2]
        )
        hidden_states[~expand_attention_mask] = 0
        # extend the attention mask to the additive 4-D form used by the encoder layers
        attention_mask = 1.0 - attention_mask[:, None, None, :].to(
            dtype=hidden_states.dtype
        )
        attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            1,
            attention_mask.shape[-1],
            attention_mask.shape[-1],
        )
        return hidden_states, attention_mask

    def encode(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        with torch.no_grad():
            # Extract features with the frozen convolutional front end
            extract_features = self.hubert.feature_extractor(input_values)
            extract_features = extract_features.transpose(1, 2)
            hidden_states = self.hubert.feature_projection(extract_features)
            hidden_states = self.hubert._mask_hidden_states(
                hidden_states, mask_time_indices=mask_time_indices
            )
            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
            hidden_states = hidden_states + position_embeddings

            if attention_mask is not None:
                # compute reduced attention_mask corresponding to feature vectors
                hidden_states, attention_mask = self._get_attention_mask(
                    hidden_states, attention_mask
                )

            # The non-stable-layer-norm variant applies layer norm before the encoder stack
            if self.hubert.config.do_stable_layer_norm is False:
                hidden_states = self.hubert.encoder.layer_norm(hidden_states)
            hidden_states = self.hubert.encoder.dropout(hidden_states)

        # Run the encoder layers below the bottleneck
        for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
            if idx < self.vq_layer - self.trainable_layers_before_vq:
                # Frozen layers: no gradients are needed upstream of the trainable window
                with torch.no_grad():
                    hidden_states = layer_module(hidden_states, attention_mask)[0]
            else:
                hidden_states = layer_module(hidden_states, attention_mask)[0]
        return hidden_states

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            _, attention_mask = self._get_attention_mask(
                hidden_states.clone(), attention_mask
            )

        # Run the remaining encoder layers. The layers beyond the trainable window
        # are frozen (requires_grad=False) but are kept on the autograd graph so the
        # distillation loss can backpropagate to the quantizer and the trainable
        # layers; wrapping them in torch.no_grad() would detach the output.
        for layer_module in self.hubert.encoder.layers[self.vq_layer :]:
            hidden_states = layer_module(hidden_states, attention_mask)[0]

        # The stable-layer-norm variant applies its layer norm after the encoder
        # stack (the non-stable variant already applied it in encode()).
        if self.hubert.config.do_stable_layer_norm:
            hidden_states = self.hubert.encoder.layer_norm(hidden_states)
        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ):
        hidden_states = self.encode(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # Quantize (encodec's VectorQuantization expects (batch, dim, time))
        hidden_states = self.quantizer_ln(hidden_states)
        quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
        quantize = quantize.transpose(1, 2)

        # Inject position embeddings
        with torch.no_grad():
            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
        quantize = quantize + position_embeddings

        # Decode
        hidden_states = self.decode(quantize, attention_mask=attention_mask)
        return hidden_states, vq_loss


@dataclass
class HubertVQOutput:
    loss: torch.Tensor
    metrics: dict[str, torch.Tensor]


class HubertVQDistill(nn.Module):
    """Trains HubertVQ by distilling the last hidden state of a frozen HuBERT teacher."""

    def __init__(
        self,
        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
        vq_layer: int = -4,  # the layer to extract the quantized features
        codebook_size: int = 1024,
        trainable_layers_before_vq: int = 2,
        trainable_layers_after_vq: int = 2,
        vq_loss_weight: float = 1.0,
    ):
        super().__init__()
        self.hubert_vq = HubertVQ(
            model_name_or_path=model_name_or_path,
            vq_layer=vq_layer,
            codebook_size=codebook_size,
            trainable_layers_before_vq=trainable_layers_before_vq,
            trainable_layers_after_vq=trainable_layers_after_vq,
        )
        self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
        self.vq_loss_weight = vq_loss_weight

        # Freeze teacher
        for param in self.hubert_teacher.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ) -> HubertVQOutput:
        hidden_states, vq_loss = self.hubert_vq(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )

        # Teacher targets
        with torch.no_grad():
            teacher_hidden_states = self.hubert_teacher(
                input_values,
                attention_mask=attention_mask,
                mask_time_indices=mask_time_indices,
            ).last_hidden_state

        distill_loss = torch.nn.functional.mse_loss(
            hidden_states, teacher_hidden_states
        )
        loss = distill_loss + vq_loss * self.vq_loss_weight
        metrics = {
            "distill_loss": distill_loss,
            "vq_loss": vq_loss,
        }
        return HubertVQOutput(loss=loss, metrics=metrics)


if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import Wav2Vec2Tokenizer

    processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
    model = HubertVQ()
    model.train()
    print("Loaded model")
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
    gt_hubert.train()
    print("Loaded ground truth model")

    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )
    print("Loaded dataset")
    input_values = processor(
        ds[0]["audio"]["array"], return_tensors="pt"
    )  # Batch size 1

    optim.zero_grad()
    # hidden_states = model.decode(model.encode(**input_values))
    hidden_states, vq_loss = model(**input_values)
    print(hidden_states, vq_loss)
    with torch.no_grad():
        gt = gt_hubert(**input_values).last_hidden_state
    loss = torch.nn.functional.mse_loss(hidden_states, gt)
    print(loss)
    total_loss = loss + vq_loss
    total_loss.backward()
    optim.step()
    print("Backward pass done")
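
    # A minimal sketch of how HubertVQDistill could be driven for one training
    # step, reusing the dummy batch above. The names distill_model, distill_optim
    # and output are illustrative additions, not part of the original script, and
    # this is a smoke test rather than a full training recipe.
    distill_model = HubertVQDistill()
    distill_model.train()
    print("Loaded distillation model")
    distill_optim = torch.optim.Adam(
        [p for p in distill_model.parameters() if p.requires_grad], lr=1e-4
    )

    distill_optim.zero_grad()
    output = distill_model(**input_values)
    output.loss.backward()
    distill_optim.step()
    print("Distillation step done:", {k: v.item() for k, v in output.metrics.items()})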