# hubert_vq.py (9.7 KB)
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. import torch
  4. from encodec.quantization.core_vq import VectorQuantization
  5. from torch import nn
  6. from transformers import HubertModel
  7. class HubertVQ(nn.Module):
  8. def __init__(
  9. self,
  10. model_name_or_path: str = "facebook/hubert-large-ls960-ft",
  11. vq_layer: int = -4, # the layer to extract the quantized features
  12. codebook_size: int = 1024,
  13. trainable_layers_before_vq: int = 2,
  14. trainable_layers_after_vq: int = 2,
  15. ):
  16. super().__init__()
  17. self.hubert = HubertModel.from_pretrained(model_name_or_path)
  18. self.vq_layer = (
  19. (self.hubert.config.num_hidden_layers + vq_layer)
  20. if vq_layer < 0
  21. else vq_layer
  22. )
  23. self.trainable_layers_before_vq = trainable_layers_before_vq
  24. self.trainable_layers_after_vq = trainable_layers_after_vq
  25. assert (
  26. self.vq_layer >= trainable_layers_before_vq
  27. and self.vq_layer
  28. < self.hubert.config.num_hidden_layers - trainable_layers_after_vq
  29. ), "vq_layer must be between trainable_layers_before_vq and num_hidden_layers - trainable_layers_after_vq"
  30. # Freeze both feature extractor & lm head
  31. for param in self.hubert.parameters():
  32. param.requires_grad = False
  33. # Unfreeze layers between vq_layer - trainable_layers_before_vq and vq_layer + trainable_layers_after_vq
  34. for param in self.hubert.encoder.layers[
  35. self.vq_layer
  36. - trainable_layers_before_vq : self.vq_layer
  37. + trainable_layers_after_vq
  38. ].parameters():
  39. param.requires_grad = True
  40. # Quantization
  41. self.quantizer = VectorQuantization(
  42. codebook_size=codebook_size,
  43. dim=self.hubert.config.hidden_size,
  44. kmeans_init=False,
  45. )
  46. @torch.no_grad()
  47. def _get_attention_mask(
  48. self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
  49. ) -> tuple[torch.Tensor, torch.Tensor]:
  50. # compute reduced attention_mask corresponding to feature vectors
  51. attention_mask = self.hubert._get_feature_vector_attention_mask(
  52. hidden_states.shape[1], attention_mask
  53. )
  54. # make sure padded tokens are not attended to
  55. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
  56. 1, 1, hidden_states.shape[2]
  57. )
  58. hidden_states[~expand_attention_mask] = 0
  59. # extend attention_mask
  60. attention_mask = 1.0 - attention_mask[:, None, None, :].to(
  61. dtype=hidden_states.dtype
  62. )
  63. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  64. attention_mask = attention_mask.expand(
  65. attention_mask.shape[0],
  66. 1,
  67. attention_mask.shape[-1],
  68. attention_mask.shape[-1],
  69. )
  70. return hidden_states, attention_mask
  71. def encode(
  72. self,
  73. input_values: Optional[torch.Tensor],
  74. attention_mask: Optional[torch.Tensor] = None,
  75. mask_time_indices: Optional[torch.FloatTensor] = None,
  76. ) -> torch.Tensor:
  77. with torch.no_grad():
  78. # Extract features
  79. extract_features = self.hubert.feature_extractor(input_values)
  80. extract_features = extract_features.transpose(1, 2)
  81. hidden_states = self.hubert.feature_projection(extract_features)
  82. hidden_states = self.hubert._mask_hidden_states(
  83. hidden_states, mask_time_indices=mask_time_indices
  84. )
  85. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  86. hidden_states = hidden_states + position_embeddings
  87. if attention_mask is not None:
  88. # compute reduced attention_mask corresponding to feature vectors
  89. hidden_states, attention_mask = self._get_attention_mask(
  90. hidden_states, attention_mask
  91. )
  92. # Only do layer norm if do_stable_layer_norm is False
  93. if self.hubert.config.do_stable_layer_norm is False:
  94. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  95. hidden_states = self.hubert.encoder.dropout(hidden_states)
  96. # Execute transformer
  97. for idx, layer_module in enumerate(self.hubert.encoder.layers[: self.vq_layer]):
  98. if idx < self.vq_layer - self.trainable_layers_before_vq:
  99. with torch.no_grad():
  100. hidden_states = layer_module(hidden_states, attention_mask)[0]
  101. else:
  102. hidden_states = layer_module(hidden_states, attention_mask)[0]
  103. return hidden_states
  104. @torch.no_grad()
  105. def decode(
  106. self,
  107. hidden_states: torch.Tensor,
  108. attention_mask: Optional[torch.Tensor] = None,
  109. ) -> torch.Tensor:
  110. if attention_mask is not None:
  111. # compute reduced attention_mask corresponding to feature vectors
  112. _, attention_mask = self._get_attention_mask(
  113. hidden_states.clone(), attention_mask
  114. )
  115. # Execute transformer
  116. for idx, layer_module in enumerate(self.hubert.encoder.layers[self.vq_layer :]):
  117. if idx >= self.trainable_layers_after_vq:
  118. with torch.no_grad():
  119. hidden_states = layer_module(hidden_states, attention_mask)[0]
  120. else:
  121. hidden_states = layer_module(hidden_states, attention_mask)[0]
  122. with torch.no_grad():
  123. # Only do layer norm if do_stable_layer_norm is False
  124. if self.hubert.config.do_stable_layer_norm is False:
  125. hidden_states = self.hubert.encoder.last_layer_norm(hidden_states)
  126. else:
  127. hidden_states = self.hubert.encoder.layer_norm(hidden_states)
  128. return hidden_states
  129. def forward(
  130. self,
  131. input_values: Optional[torch.Tensor],
  132. attention_mask: Optional[torch.Tensor] = None,
  133. mask_time_indices: Optional[torch.FloatTensor] = None,
  134. ):
  135. hidden_states = self.encode(
  136. input_values,
  137. attention_mask=attention_mask,
  138. mask_time_indices=mask_time_indices,
  139. )
  140. # Quantize
  141. quantize, _, vq_loss = self.quantizer(hidden_states.transpose(1, 2))
  142. quantize = quantize.transpose(1, 2)
  143. # Inject position embeddings
  144. with torch.no_grad():
  145. position_embeddings = self.hubert.encoder.pos_conv_embed(hidden_states)
  146. quantize = quantize + position_embeddings
  147. # Decode
  148. hidden_states = self.decode(quantize, attention_mask=attention_mask)
  149. return hidden_states, vq_loss
  150. @dataclass
  151. class HubertVQOutput:
  152. loss: torch.Tensor
  153. metrics: dict[str, torch.Tensor]
  154. class HubertVQDistill(nn.Module):
  155. def __init__(
  156. self,
  157. model_name_or_path: str = "facebook/hubert-large-ls960-ft",
  158. vq_layer: int = -4, # the layer to extract the quantized features
  159. codebook_size: int = 1024,
  160. trainable_layers_before_vq: int = 2,
  161. trainable_layers_after_vq: int = 2,
  162. vq_loss_weight: float = 1.0,
  163. ):
  164. super().__init__()
  165. self.hubert_vq = HubertVQ(
  166. model_name_or_path=model_name_or_path,
  167. vq_layer=vq_layer,
  168. codebook_size=codebook_size,
  169. trainable_layers_before_vq=trainable_layers_before_vq,
  170. trainable_layers_after_vq=trainable_layers_after_vq,
  171. )
  172. self.hubert_teacher = HubertModel.from_pretrained(model_name_or_path)
  173. self.vq_loss_weight = vq_loss_weight
  174. # Freeze teacher
  175. for param in self.hubert_teacher.parameters():
  176. param.requires_grad = False
  177. def forward(
  178. self,
  179. input_values: Optional[torch.Tensor],
  180. attention_mask: Optional[torch.Tensor] = None,
  181. mask_time_indices: Optional[torch.FloatTensor] = None,
  182. ) -> HubertVQOutput:
  183. hidden_states, vq_loss = self.hubert_vq(
  184. input_values,
  185. attention_mask=attention_mask,
  186. mask_time_indices=mask_time_indices,
  187. )
  188. # Teacher
  189. with torch.no_grad():
  190. teacher_hidden_states = self.hubert_teacher(
  191. input_values,
  192. attention_mask=attention_mask,
  193. mask_time_indices=mask_time_indices,
  194. ).last_hidden_state
  195. distill_loss = torch.nn.functional.mse_loss(
  196. hidden_states, teacher_hidden_states
  197. )
  198. loss = distill_loss + vq_loss * self.vq_loss_weight
  199. metrics = {
  200. "distill_loss": distill_loss,
  201. "vq_loss": vq_loss,
  202. }
  203. return HubertVQOutput(loss=loss, metrics=metrics)
  204. if __name__ == "__main__":
  205. from datasets import load_dataset
  206. from transformers import Wav2Vec2Tokenizer
  207. processor = Wav2Vec2Tokenizer.from_pretrained("facebook/hubert-large-ls960-ft")
  208. model = HubertVQ()
  209. model.train()
  210. print("Loaded model")
  211. optim = torch.optim.Adam(model.parameters(), lr=1e-4)
  212. gt_hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  213. gt_hubert.train()
  214. print("Loaded ground truth model")
  215. ds = load_dataset(
  216. "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
  217. )
  218. print("Loaded dataset")
  219. input_values = processor(
  220. ds[0]["audio"]["array"], return_tensors="pt"
  221. ) # Batch size 1
  222. optim.zero_grad()
  223. # hidden_states = model.decode(model.encode(**input_values))
  224. hidden_states, vq_loss = model(**input_values)
  225. print(hidden_states, vq_loss)
  226. gt = gt_hubert(**input_values).last_hidden_state
  227. loss = torch.nn.functional.mse_loss(hidden_states, gt)
  228. print(loss)
  229. total_loss = loss + vq_loss
  230. total_loss.backward()
  231. optim.step()
  232. print("Backward pass done")