# (removed: extraction artifact — collapsed line-number residue from the paste)
- from dataclasses import dataclass
- from typing import Optional
- import torch
- from encodec.quantization.core_vq import VectorQuantization
- from torch import nn
- from transformers import HubertModel
class HubertVQ(nn.Module):
    """HuBERT with a vector-quantization bottleneck at an intermediate encoder layer.

    The pretrained HuBERT is frozen except for a small window of encoder layers
    around the quantization point: ``trainable_layers_before_vq`` layers feed the
    quantizer and ``trainable_layers_after_vq`` layers consume its output. The
    forward pass is encode -> quantize -> decode, returning the final hidden
    states together with the VQ (commitment) loss.
    """

    def __init__(
        self,
        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
        vq_layer: int = -4,  # the layer to extract the quantized features (negative = from the end)
        codebook_size: int = 1024,
        trainable_layers_before_vq: int = 2,
        trainable_layers_after_vq: int = 2,
    ):
        super().__init__()
        self.hubert = HubertModel.from_pretrained(model_name_or_path)
        # Normalize a negative layer index to an absolute position.
        self.vq_layer = (
            (self.hubert.config.num_hidden_layers + vq_layer)
            if vq_layer < 0
            else vq_layer
        )
        self.trainable_layers_before_vq = trainable_layers_before_vq
        self.trainable_layers_after_vq = trainable_layers_after_vq
        # The trainable window must fit inside the encoder stack.
        assert (
            self.vq_layer >= trainable_layers_before_vq
            and self.vq_layer
            < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
        ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"
        # Freeze the whole model, then unfreeze the window around the VQ point.
        for param in self.hubert.parameters():
            param.requires_grad = False
        for param in self.hubert.encoder.layers[
            self.vq_layer
            - trainable_layers_before_vq : self.vq_layer
            + trainable_layers_after_vq
        ].parameters():
            param.requires_grad = True
        # Single-codebook vector quantizer over the hidden dimension.
        self.quantizer = VectorQuantization(
            codebook_size=codebook_size,
            dim=self.hubert.config.hidden_size,
            kmeans_init=False,
        )

    @torch.no_grad()
    def _get_attention_mask(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Reduce a sample-level attention mask to feature frames and expand it
        to the additive (large-negative) form the encoder layers expect.

        NOTE: zeroes padded frames of ``hidden_states`` *in place*; pass a clone
        if the caller's tensor must stay intact.
        """
        # compute reduced attention_mask corresponding to feature vectors
        attention_mask = self.hubert._get_feature_vector_attention_mask(
            hidden_states.shape[1], attention_mask
        )
        # make sure padded tokens are not attended to
        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
            1, 1, hidden_states.shape[2]
        )
        hidden_states[~expand_attention_mask] = 0
        # extend attention_mask: 0 for kept positions, large negative for padding
        attention_mask = 1.0 - attention_mask[:, None, None, :].to(
            dtype=hidden_states.dtype
        )
        attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
        attention_mask = attention_mask.expand(
            attention_mask.shape[0],
            1,
            attention_mask.shape[-1],
            attention_mask.shape[-1],
        )
        return hidden_states, attention_mask

    def encode(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Run HuBERT up to (but not including) ``self.vq_layer``.

        The frozen front end (feature extractor, projection, positional conv,
        leading frozen layers) runs under ``no_grad``: nothing trainable sits
        before it, so skipping graph construction there is safe and saves memory.
        """
        with torch.no_grad():
            # Extract convolutional features
            extract_features = self.hubert.feature_extractor(input_values)
            extract_features = extract_features.transpose(1, 2)
            hidden_states = self.hubert.feature_projection(extract_features)
            hidden_states = self.hubert._mask_hidden_states(
                hidden_states, mask_time_indices=mask_time_indices
            )
            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
            hidden_states = hidden_states + position_embeddings
            if attention_mask is not None:
                # compute reduced attention_mask corresponding to feature vectors
                hidden_states, attention_mask = self._get_attention_mask(
                    hidden_states, attention_mask
                )
            # Non-stable-layer-norm encoders normalize before the layer stack
            if self.hubert.config.do_stable_layer_norm is False:
                hidden_states = self.hubert.encoder.layer_norm(hidden_states)
            hidden_states = self.hubert.encoder.dropout(hidden_states)
        # Execute the transformer layers before the VQ point; only the last
        # `trainable_layers_before_vq` of them need gradients.
        for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
            if idx < self.vq_layer - self.trainable_layers_before_vq:
                with torch.no_grad():
                    hidden_states = layer_module(hidden_states, attention_mask)[0]
            else:
                hidden_states = layer_module(hidden_states, attention_mask)[0]
        return hidden_states

    def decode(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the encoder layers after the VQ point plus the final layer norm.

        Unlike :meth:`encode`, gradients must stay enabled for the whole pass:
        the trainable layers sit at the *head* of this range, so wrapping the
        frozen tail (or the whole method, as the previous ``@torch.no_grad()``
        decorator did) in ``no_grad`` detaches the output and severs every
        gradient path from a loss on the result back to the trainable layers
        and the quantizer. The frozen layers still receive no updates because
        their parameters have ``requires_grad=False``.
        """
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors;
            # clone so the in-place zeroing does not modify the caller's tensor
            _, attention_mask = self._get_attention_mask(
                hidden_states.clone(), attention_mask
            )
        # Execute the remaining transformer layers with grad enabled so the
        # loss can reach the trainable layers at the head of this range.
        for layer_module in self.hubert.encoder.layers[self.vq_layer :]:
            hidden_states = layer_module(hidden_states, attention_mask)[0]
        # NOTE(review): HF's non-stable HubertEncoder applies its `layer_norm`
        # *before* the layer stack and defines no `last_layer_norm` attribute —
        # confirm this branch against the installed transformers version.
        if self.hubert.config.do_stable_layer_norm is False:
            hidden_states = self.hubert.encoder.last_layer_norm(hidden_states)
        else:
            hidden_states = self.hubert.encoder.layer_norm(hidden_states)
        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ):
        """Encode, quantize, and decode.

        Returns:
            (hidden_states, vq_loss): final hidden states and the quantizer's
            auxiliary loss.
        """
        hidden_states = self.encode(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )
        # Quantize (the quantizer expects channels-first: B, D, T)
        quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
        quantize = quantize.transpose(1, 2)
        # Re-inject position embeddings (computed from the pre-quantized states;
        # detached on purpose — the positional conv is frozen).
        with torch.no_grad():
            position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
        quantize = quantize + position_embeddings
        # Decode through the remaining layers
        hidden_states = self.decode(quantize, attention_mask=attention_mask)
        return hidden_states, vq_loss
@dataclass
class HubertVQOutput:
    """Training-step output bundle for HubertVQDistill."""

    # Scalar objective to backpropagate (distill loss + weighted VQ loss).
    loss: torch.Tensor
    # Named scalar components for logging, e.g. "distill_loss", "vq_loss".
    metrics: dict[str, torch.Tensor]
class HubertVQDistill(nn.Module):
    """Distillation wrapper: trains a :class:`HubertVQ` student against a
    frozen pretrained HuBERT teacher via an MSE loss on the final hidden states,
    plus a weighted VQ loss."""

    def __init__(
        self,
        model_name_or_path: str = "facebook/hubert-large-ls960-ft",
        vq_layer: int = -4,  # the layer to extract the quantized features
        codebook_size: int = 1024,
        trainable_layers_before_vq: int = 2,
        trainable_layers_after_vq: int = 2,
        vq_loss_weight: float = 1.0,
    ):
        super().__init__()
        self.hubert_vq = HubertVQ(
            model_name_or_path=model_name_or_path,
            vq_layer=vq_layer,
            codebook_size=codebook_size,
            trainable_layers_before_vq=trainable_layers_before_vq,
            trainable_layers_after_vq=trainable_layers_after_vq,
        )
        self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
        self.vq_loss_weight = vq_loss_weight
        # The teacher is a fixed target — freeze all of its parameters.
        self.hubert_teacher.requires_grad_(False)

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
    ) -> HubertVQOutput:
        """Return the combined distillation + VQ objective and its components."""
        student_states, vq_loss = self.hubert_vq(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
        )
        # Teacher target — no graph needed.
        with torch.no_grad():
            teacher_states = self.hubert_teacher(
                input_values,
                attention_mask=attention_mask,
                mask_time_indices=mask_time_indices,
            ).last_hidden_state
        distill_loss = torch.nn.functional.mse_loss(student_states, teacher_states)
        return HubertVQOutput(
            loss=distill_loss + vq_loss * self.vq_loss_weight,
            metrics={"distill_loss": distill_loss, "vq_loss": vq_loss},
        )
if __name__ == "__main__":
    # Smoke test: one training step of HubertVQ against a pretrained reference.
    from datasets import load_dataset
    from transformers import Wav2Vec2Tokenizer

    processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
    model = HubertVQ()
    model.train()
    print("Loaded model")
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
    # The reference model is a fixed regression target: eval() disables dropout
    # (train() made the target nondeterministic), and freezing it keeps
    # backward() from accumulating gradients in a model the optimizer never steps.
    gt_hubert.eval()
    for param in gt_hubert.parameters():
        param.requires_grad = False
    print("Loaded ground truth model")
    ds = load_dataset(
        "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
    )
    print("Loaded dataset")
    input_values = processor(
        ds[0]["audio"]["array"], return_tensors="pt"
    )  # Batch size 1
    optim.zero_grad()
    # hidden_states = model.decode(model.encode(**input_values))
    hidden_states, vq_loss = model(**input_values)
    print(hidden_states, vq_loss)
    # No graph through the frozen target
    with torch.no_grad():
        gt = gt_hubert(**input_values).last_hidden_state
    loss = torch.nn.functional.mse_loss(hidden_states, gt)
    print(loss)
    total_loss = loss + vq_loss
    total_loss.backward()
    optim.step()
    print("Backward pass done")
|