# whisper_vq.py
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. import torch
  4. from torch import nn
  5. from vector_quantize_pytorch import VectorQuantize
  6. from fish_speech.modules.flash_whisper import (
  7. FlashWhisperEncoderLayer,
  8. FlashWhisperForConditionalGeneration,
  9. )
  10. @dataclass
  11. class WhisperVQOutput:
  12. loss: torch.Tensor
  13. metrics: dict[str, torch.Tensor]
  14. class WhisperVQ(nn.Module):
  15. def __init__(
  16. self,
  17. model_name_or_path: str = "openai/whisper-medium",
  18. # Quantization
  19. codebook_dim: int = 32,
  20. codebook_size: int = 4096,
  21. codebook_decay: float = 0.9,
  22. threshold_ema_dead_code: int = 0,
  23. use_cosine_similarity: bool = True,
  24. downsample: bool = True,
  25. # Attention
  26. post_attention_depth: int = 2,
  27. ):
  28. super().__init__()
  29. self.whisper = FlashWhisperForConditionalGeneration.from_pretrained(
  30. model_name_or_path
  31. )
  32. self.whisper.gradient_checkpointing_enable()
  33. # Freeze Whisper
  34. for param in self.whisper.parameters():
  35. param.requires_grad = False
  36. # Store vars
  37. self.downsample = downsample
  38. self.codebook_dim = codebook_dim
  39. self.codebook_size = codebook_size
  40. # Pre-quantization
  41. whisper_config = self.whisper.model.config
  42. encoder_width = whisper_config.encoder_attention_heads * 64
  43. self.pre_ln = nn.LayerNorm(encoder_width)
  44. self.pre_mlp = nn.Sequential(
  45. nn.Linear(encoder_width, whisper_config.encoder_ffn_dim),
  46. nn.GELU(),
  47. nn.Linear(whisper_config.encoder_ffn_dim, encoder_width),
  48. )
  49. # Quantization
  50. self.quantizer = VectorQuantize(
  51. dim=encoder_width,
  52. codebook_size=codebook_size,
  53. codebook_dim=codebook_dim,
  54. decay=codebook_decay,
  55. commitment_weight=1.0,
  56. threshold_ema_dead_code=threshold_ema_dead_code,
  57. use_cosine_sim=use_cosine_similarity,
  58. )
  59. self.pad_embedding = nn.Parameter(torch.randn(encoder_width))
  60. # Post-quantization
  61. self.post_positional_embedding = nn.Embedding(
  62. whisper_config.max_source_positions, encoder_width
  63. )
  64. self.post_attention = nn.Sequential(
  65. *[
  66. FlashWhisperEncoderLayer(
  67. config=whisper_config,
  68. )
  69. for _ in range(post_attention_depth)
  70. ]
  71. )
  72. self.post_ln = nn.LayerNorm(encoder_width)
  73. def encode(
  74. self,
  75. input_features: Optional[torch.Tensor],
  76. attention_mask: Optional[torch.Tensor] = None,
  77. ) -> torch.Tensor:
  78. if attention_mask is not None:
  79. assert attention_mask.ndim == 2, "Attention mask must be 2D"
  80. # Whisper will downsample by 2
  81. attention_mask = attention_mask[:, ::2]
  82. with torch.no_grad():
  83. hidden_states = self.whisper.model.encoder(
  84. input_features,
  85. ).last_hidden_state
  86. x = hidden_states
  87. if self.downsample:
  88. x = x.reshape(x.shape[0], x.shape[1] // 2, 2, x.shape[2]).mean(dim=2)
  89. if attention_mask is not None:
  90. attention_mask = attention_mask[:, ::2]
  91. x = x + self.pre_mlp(self.pre_ln(x))
  92. quantized, indices, loss = self.quantizer(
  93. x, mask=attention_mask.bool() if attention_mask is not None else None
  94. )
  95. # Fill masked positions with pad embedding
  96. if attention_mask is not None:
  97. quantized[attention_mask == 0] = self.pad_embedding
  98. return quantized, indices, loss, hidden_states
  99. def decode(
  100. self,
  101. hidden_states: torch.Tensor,
  102. ) -> torch.Tensor:
  103. # Upsample
  104. if self.downsample:
  105. hidden_states = hidden_states.repeat_interleave(2, dim=1)
  106. # Inject position embeddings
  107. positions = torch.arange(
  108. 0, hidden_states.shape[1], dtype=torch.long, device=hidden_states.device
  109. )
  110. x = hidden_states + self.post_positional_embedding(positions)
  111. # Decode
  112. for layer in self.post_attention:
  113. x = layer(x, None, None)[0]
  114. hidden_states = self.post_ln(hidden_states)
  115. return hidden_states
  116. def forward(
  117. self,
  118. input_features: torch.Tensor,
  119. encoder_attention_mask: torch.Tensor,
  120. decoder_input_ids: torch.Tensor,
  121. decoder_attention_mask: torch.Tensor,
  122. labels: torch.Tensor,
  123. # Audio, not used here
  124. input_values: Optional[torch.Tensor] = None,
  125. ) -> WhisperVQOutput:
  126. quantize, _, vq_loss, teacher_hidden_states = self.encode(
  127. input_features=input_features,
  128. attention_mask=encoder_attention_mask,
  129. )
  130. vq_hidden_states = self.decode(quantize)
  131. # student cross entropy loss
  132. outputs = self.whisper(
  133. encoder_outputs=(vq_hidden_states,),
  134. decoder_input_ids=decoder_input_ids,
  135. decoder_attention_mask=decoder_attention_mask,
  136. labels=labels,
  137. )
  138. student_ce_loss = outputs.loss
  139. student_logits = outputs.logits
  140. # teacher cross entropy loss
  141. with torch.no_grad():
  142. outputs = self.whisper(
  143. encoder_outputs=(teacher_hidden_states,),
  144. decoder_input_ids=decoder_input_ids,
  145. decoder_attention_mask=decoder_attention_mask,
  146. labels=labels,
  147. )
  148. teacher_ce_loss = outputs.loss
  149. teacher_logits = outputs.logits
  150. # KL divergence
  151. kl_loss = nn.functional.kl_div(
  152. nn.functional.log_softmax(student_logits, dim=-1),
  153. nn.functional.softmax(teacher_logits, dim=-1),
  154. reduction="batchmean",
  155. )
  156. loss = vq_loss + student_ce_loss + kl_loss
  157. return WhisperVQOutput(
  158. loss=loss,
  159. metrics={
  160. "vq_loss": vq_loss,
  161. "student_ce_loss": student_ce_loss,
  162. "teacher_ce_loss": teacher_ce_loss,
  163. "kl_loss": kl_loss,
  164. },
  165. )
  166. if __name__ == "__main__":
  167. from torch.utils.data import DataLoader
  168. from transformers import WhisperProcessor
  169. from fish_speech.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
  170. processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
  171. model = WhisperVQ()
  172. ds = WhisperVQDataset(
  173. "filelists/whisper-vq.train.test.filelist", "openai/whisper-medium"
  174. )
  175. loader = DataLoader(ds, batch_size=8, collate_fn=WhisperVQCollator())
  176. for batch in loader:
  177. output = model(**batch)
  178. print(output)
  179. break