# hubert_vq.py
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. import torch
  4. from encodec.quantization.core_vq import VectorQuantization
  5. from torch import nn
  6. from transformers import HubertModel
  7. class HubertVQ(nn.Module):
  8. def __init__(
  9. self,
  10. model_name_or_path: str = "facebook/hubert-large-ls960-ft",
  11. vq_layer: int = -4, # the layer to extract the quantized features
  12. codebook_size: int = 1024,
  13. trainable_layers_before_vq: int = 2,
  14. trainable_layers_after_vq: int = 2,
  15. ):
  16. super().__init__()
  17. self.hubert = HubertModel.from_pretrained(model_name_or_path)
  18. self.vq_layer = (
  19. (self.hubert.config.num_hidden_layers + vq_layer)
  20. if vq_layer < 0
  21. else vq_layer
  22. )
  23. self.trainable_layers_before_vq = trainable_layers_before_vq
  24. self.trainable_layers_after_vq = trainable_layers_after_vq
  25. assert (
  26. self.vq_layer >= trainable_layers_before_vq
  27. and self.vq_layer
  28. < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
  29. ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"
  30. # Freeze both feature extractor & lm head
  31. for param in self.hubert.parameters():
  32. param.requires_grad = False
  33. # Unfreeze layers between vq_layer - trainable_layers_before_vq and vq_layer + trainable_layers_after_vq
  34. for param in self.hubert.encoder.layers[
  35. self.vq_layer
  36. - trainable_layers_before_vq : self.vq_layer
  37. + trainable_layers_after_vq
  38. ].parameters():
  39. param.requires_grad = True
  40. # Quantization
  41. self.quantizer_ln = nn.LayerNorm(self.hubert.config.hidden_size)
  42. self.quantizer = VectorQuantization(
  43. codebook_size=codebook_size,
  44. dim=self.hubert.config.hidden_size,
  45. kmeans_init=False,
  46. )
  47. @torch.no_grad()
  48. def _get_attention_mask(
  49. self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
  50. ) -> tuple[torch.Tensor, torch.Tensor]:
  51. # compute reduced attention_mask corresponding to feature vectors
  52. attention_mask = self.hubert._get_feature_vector_attention_mask(
  53. hidden_states.shape[1], attention_mask
  54. )
  55. # make sure padded tokens are not attended to
  56. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
  57. 1, 1, hidden_states.shape[2]
  58. )
  59. hidden_states[~expand_attention_mask] = 0
  60. # extend attention_mask
  61. attention_mask = 1.0 - attention_mask[:, None, None, :].to(
  62. dtype=hidden_states.dtype
  63. )
  64. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  65. attention_mask = attention_mask.expand(
  66. attention_mask.shape[0],
  67. 1,
  68. attention_mask.shape[-1],
  69. attention_mask.shape[-1],
  70. )
  71. return hidden_states, attention_mask
  72. def encode(
  73. self,
  74. input_values: Optional[torch.Tensor],
  75. attention_mask: Optional[torch.Tensor] = None,
  76. mask_time_indices: Optional[torch.FloatTensor] = None,
  77. ) -> torch.Tensor:
  78. with torch.no_grad():
  79. # Extract features
  80. extract_features = self.hubert.feature_extractor(input_values)
  81. extract_features = extract_features.transpose(1, 2)
  82. hidden_states = self.hubert.feature_projection(extract_features)
  83. hidden_states = self.hubert._mask_hidden_states(
  84. hidden_states, mask_time_indices=mask_time_indices
  85. )
  86. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  87. hidden_states = hidden_states + position_embeddings
  88. if attention_mask is not None:
  89. # compute reduced attention_mask corresponding to feature vectors
  90. hidden_states, attention_mask = self._get_attention_mask(
  91. hidden_states, attention_mask
  92. )
  93. # Only do layer norm if do_stable_layer_norm is False
  94. if self.hubert.config.do_stable_layer_norm is False:
  95. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  96. hidden_states = self.hubert.encoder.dropout(hidden_states)
  97. # Execute transformer
  98. for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
  99. if idx < self.vq_layer - self.trainable_layers_before_vq:
  100. with torch.no_grad():
  101. hidden_states = layer_module(hidden_states, attention_mask)[0]
  102. else:
  103. hidden_states = layer_module(hidden_states, attention_mask)[0]
  104. return hidden_states
  105. @torch.no_grad()
  106. def decode(
  107. self,
  108. hidden_states: torch.Tensor,
  109. attention_mask: Optional[torch.Tensor] = None,
  110. ) -> torch.Tensor:
  111. if attention_mask is not None:
  112. # compute reduced attention_mask corresponding to feature vectors
  113. _, attention_mask = self._get_attention_mask(
  114. hidden_states.clone(), attention_mask
  115. )
  116. # Execute transformer
  117. for idx, layer_module in enumerate(self.hubert.encoder.layers[self.vq_layer :]):
  118. if idx >= self.trainable_layers_after_vq:
  119. with torch.no_grad():
  120. hidden_states = layer_module(hidden_states, attention_mask)[0]
  121. else:
  122. hidden_states = layer_module(hidden_states, attention_mask)[0]
  123. with torch.no_grad():
  124. # Only do layer norm if do_stable_layer_norm is False
  125. if self.hubert.config.do_stable_layer_norm is False:
  126. hidden_states = self.hubert.encoder.last_layer_norm(hidden_states)
  127. else:
  128. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  129. return hidden_states
  130. def forward(
  131. self,
  132. input_values: Optional[torch.Tensor],
  133. attention_mask: Optional[torch.Tensor] = None,
  134. mask_time_indices: Optional[torch.FloatTensor] = None,
  135. ):
  136. hidden_states = self.encode(
  137. input_values,
  138. attention_mask=attention_mask,
  139. mask_time_indices=mask_time_indices,
  140. )
  141. # Quantize
  142. hidden_states = self.quantizer_ln(hidden_states)
  143. quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
  144. quantize = quantize.transpose(1, 2)
  145. # Inject position embeddings
  146. with torch.no_grad():
  147. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  148. quantize = quantize + position_embeddings
  149. # Decode
  150. hidden_states = self.decode(quantize, attention_mask=attention_mask)
  151. return hidden_states, vq_loss
@dataclass
class HubertVQOutput:
    """Return value of ``HubertVQDistill.forward``."""

    # Total training loss: distillation MSE plus the weighted VQ loss.
    loss: torch.Tensor
    # Individual loss terms for logging ("distill_loss", "vq_loss").
    metrics: dict[str, torch.Tensor]
  156. class HubertVQDistill(nn.Module):
  157. def __init__(
  158. self,
  159. model_name_or_path: str = "facebook/hubert-large-ls960-ft",
  160. vq_layer: int = -4, # the layer to extract the quantized features
  161. codebook_size: int = 1024,
  162. trainable_layers_before_vq: int = 2,
  163. trainable_layers_after_vq: int = 2,
  164. vq_loss_weight: float = 1.0,
  165. ):
  166. super().__init__()
  167. self.hubert_vq = HubertVQ(
  168. model_name_or_path=model_name_or_path,
  169. vq_layer=vq_layer,
  170. codebook_size=codebook_size,
  171. trainable_layers_before_vq=trainable_layers_before_vq,
  172. trainable_layers_after_vq=trainable_layers_after_vq,
  173. )
  174. self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
  175. self.vq_loss_weight = vq_loss_weight
  176. # Freeze teacher
  177. for param in self.hubert_teacher.parameters():
  178. param.requires_grad = False
  179. def forward(
  180. self,
  181. input_values: Optional[torch.Tensor],
  182. attention_mask: Optional[torch.Tensor] = None,
  183. mask_time_indices: Optional[torch.FloatTensor] = None,
  184. ) -> HubertVQOutput:
  185. hidden_states, vq_loss = self.hubert_vq(
  186. input_values,
  187. attention_mask=attention_mask,
  188. mask_time_indices=mask_time_indices,
  189. )
  190. # Teacher
  191. with torch.no_grad():
  192. teacher_hidden_states = self.hubert_teacher(
  193. input_values,
  194. attention_mask=attention_mask,
  195. mask_time_indices=mask_time_indices,
  196. ).last_hidden_state
  197. distill_loss = torch.nn.functional.mse_loss(
  198. hidden_states, teacher_hidden_states
  199. )
  200. loss = distill_loss + vq_loss * self.vq_loss_weight
  201. metrics = {
  202. "distill_loss": distill_loss,
  203. "vq_loss": vq_loss,
  204. }
  205. return HubertVQOutput(loss=loss, metrics=metrics)
  206. if __name__ == "__main__":
  207. from datasets import load_dataset
  208. from transformers import Wav2Vec2Tokenizer
  209. processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
  210. model = HubertVQ()
  211. model.train()
  212. print("Loaded model")
  213. optim = torch.optim.Adam(model.parameters(), lr=1e-4)
  214. gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  215. gt_hubert.train()
  216. print("Loaded ground truth model")
  217. ds = load_dataset(
  218. "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
  219. )
  220. print("Loaded dataset")
  221. input_values = processor(
  222. ds[0]["audio"]["array"], return_tensors="pt"
  223. ) # Batch size 1
  224. optim.zero_grad()
  225. # hidden_states = model.decode(model.encode(**input_values))
  226. hidden_states, vq_loss = model(**input_values)
  227. print(hidden_states, vq_loss)
  228. gt = gt_hubert(**input_values).last_hidden_state
  229. loss = torch.nn.functional.mse_loss(hidden_states, gt)
  230. print(loss)
  231. total_loss = loss + vq_loss
  232. total_loss.backward()
  233. optim.step()
  234. print("Backward pass done")