# hubert_vq.py — HuBERT encoder with a vector-quantization bottleneck.
  1. from typing import Optional
  2. from transformers import HubertModel
  3. from torch import nn
  4. import torch
  5. from encodec.quantization.core_vq import VectorQuantization
  6. class HubertVQ(nn.Module):
  7. def __init__(
  8. self,
  9. model_name_or_path: str = "facebook/hubert-large-ls960-ft",
  10. vq_layer: int = -4, # the layer to extract the quantized features
  11. codebook_size: int = 1024,
  12. trainable_layers_before_vq: int = 2,
  13. trainable_layers_after_vq: int = 2,
  14. ):
  15. super().__init__()
  16. self.hubert = HubertModel.from_pretrained(model_name_or_path)
  17. self.vq_layer = (
  18. (self.hubert.config.num_hidden_layers + vq_layer)
  19. if vq_layer < 0
  20. else vq_layer
  21. )
  22. self.trainable_layers_before_vq = trainable_layers_before_vq
  23. self.trainable_layers_after_vq = trainable_layers_after_vq
  24. assert (
  25. self.vq_layer >= trainable_layers_before_vq
  26. and self.vq_layer
  27. < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
  28. ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"
  29. # Freeze both feature extractor & lm head
  30. for param in self.hubert.parameters():
  31. param.requires_grad = False
  32. # Unfreeze layers between vq_layer - trainable_layers_before_vq and vq_layer + trainable_layers_after_vq
  33. for param in self.hubert.encoder.layers[
  34. self.vq_layer
  35. - trainable_layers_before_vq : self.vq_layer
  36. + trainable_layers_after_vq
  37. ].parameters():
  38. param.requires_grad = True
  39. # Quantization
  40. self.quantizer = VectorQuantization(
  41. codebook_size=codebook_size,
  42. dim=self.hubert.config.hidden_size,
  43. kmeans_init=False,
  44. )
  45. @torch.no_grad()
  46. def _get_attention_mask(
  47. self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
  48. ) -> tuple[torch.Tensor, torch.Tensor]:
  49. # compute reduced attention_mask corresponding to feature vectors
  50. attention_mask = self.hubert._get_feature_vector_attention_mask(
  51. hidden_states.shape[1], attention_mask
  52. )
  53. # make sure padded tokens are not attended to
  54. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
  55. 1, 1, hidden_states.shape[2]
  56. )
  57. hidden_states[~expand_attention_mask] = 0
  58. # extend attention_mask
  59. attention_mask = 1.0 - attention_mask[:, None, None, :].to(
  60. dtype=hidden_states.dtype
  61. )
  62. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  63. attention_mask = attention_mask.expand(
  64. attention_mask.shape[0],
  65. 1,
  66. attention_mask.shape[-1],
  67. attention_mask.shape[-1],
  68. )
  69. return hidden_states, attention_mask
  70. def encode(
  71. self,
  72. input_values: Optional[torch.Tensor],
  73. attention_mask: Optional[torch.Tensor] = None,
  74. mask_time_indices: Optional[torch.FloatTensor] = None,
  75. ) -> torch.Tensor:
  76. with torch.no_grad():
  77. # Extract features
  78. extract_features = self.hubert.feature_extractor(input_values)
  79. extract_features = extract_features.transpose(1, 2)
  80. hidden_states = self.hubert.feature_projection(extract_features)
  81. hidden_states = self.hubert._mask_hidden_states(
  82. hidden_states, mask_time_indices=mask_time_indices
  83. )
  84. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  85. hidden_states = hidden_states + position_embeddings
  86. if attention_mask is not None:
  87. # compute reduced attention_mask corresponding to feature vectors
  88. hidden_states, attention_mask = self._get_attention_mask(
  89. hidden_states, attention_mask
  90. )
  91. # Only do layer norm if do_stable_layer_norm is False
  92. if self.hubert.config.do_stable_layer_norm is False:
  93. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  94. hidden_states = self.hubert.encoder.dropout(hidden_states)
  95. # Execute transformer
  96. for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
  97. if idx < self.vq_layer - self.trainable_layers_before_vq:
  98. with torch.no_grad():
  99. hidden_states = layer_module(hidden_states, attention_mask)[0]
  100. else:
  101. hidden_states = layer_module(hidden_states, attention_mask)[0]
  102. return hidden_states
  103. @torch.no_grad()
  104. def decode(
  105. self,
  106. hidden_states: torch.Tensor,
  107. attention_mask: Optional[torch.Tensor] = None,
  108. ) -> torch.Tensor:
  109. if attention_mask is not None:
  110. # compute reduced attention_mask corresponding to feature vectors
  111. _, attention_mask = self._get_attention_mask(
  112. hidden_states.clone(), attention_mask
  113. )
  114. # Execute transformer
  115. for idx, layer_module in enumerate(self.hubert.encoder.layers[self.vq_layer :]):
  116. if idx >= self.trainable_layers_after_vq:
  117. with torch.no_grad():
  118. hidden_states = layer_module(hidden_states, attention_mask)[0]
  119. else:
  120. hidden_states = layer_module(hidden_states, attention_mask)[0]
  121. with torch.no_grad():
  122. # Only do layer norm if do_stable_layer_norm is False
  123. if self.hubert.config.do_stable_layer_norm is False:
  124. hidden_states = self.hubert.encoder.last_layer_norm(hidden_states)
  125. else:
  126. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  127. return hidden_states
  128. def forward(
  129. self,
  130. input_values: Optional[torch.Tensor],
  131. attention_mask: Optional[torch.Tensor] = None,
  132. mask_time_indices: Optional[torch.FloatTensor] = None,
  133. ):
  134. hidden_states = self.encode(
  135. input_values,
  136. attention_mask=attention_mask,
  137. mask_time_indices=mask_time_indices,
  138. )
  139. # Quantize
  140. quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
  141. quantize = quantize.transpose(1, 2)
  142. # Inject position embeddings
  143. with torch.no_grad():
  144. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  145. quantize = quantize + position_embeddings
  146. # Decode
  147. hidden_states = self.decode(quantize, attention_mask=attention_mask)
  148. return hidden_states, vq_loss
  149. # class HubertVQ
  150. if __name__ == "__main__":
  151. from transformers import Wav2Vec2Tokenizer
  152. from datasets import load_dataset
  153. processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
  154. model = HubertVQ()
  155. model.train()
  156. print("Loaded model")
  157. optim = torch.optim.Adam(model.parameters(), lr=1e-4)
  158. gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  159. gt_hubert.train()
  160. print("Loaded ground truth model")
  161. ds = load_dataset(
  162. "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
  163. )
  164. print("Loaded dataset")
  165. input_values = processor(
  166. ds[0]["audio"]["array"], return_tensors="pt"
  167. ) # Batch size 1
  168. optim.zero_grad()
  169. # hidden_states = model.decode(model.encode(**input_values))
  170. hidden_states, vq_loss = model(**input_values)
  171. print(hidden_states, vq_loss)
  172. gt = gt_hubert(**input_values).last_hidden_state
  173. loss = torch.nn.functional.mse_loss(hidden_states, gt)
  174. print(loss)
  175. total_loss = loss + vq_loss
  176. total_loss.backward()
  177. optim.step()
  178. print("Backward pass done")