# whisper_vq.py
  1. from dataclasses import dataclass
  2. from pathlib import Path
  3. import librosa
  4. import torch
  5. from torch.utils.data import Dataset
  6. from transformers import WhisperProcessor
  7. from whisper.audio import HOP_LENGTH, load_audio, log_mel_spectrogram, pad_or_trim
  8. class WhisperVQDataset(Dataset):
  9. def __init__(
  10. self, filelist: str, model_name_or_path: str = "openai/whisper-medium"
  11. ):
  12. super().__init__()
  13. self.files = [
  14. Path(line.strip()) for line in Path(filelist).read_text().splitlines()
  15. ]
  16. self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
  17. def __len__(self):
  18. return len(self.files)
  19. def __getitem__(self, idx):
  20. file = self.files[idx]
  21. wav = load_audio(file)
  22. wav_length = wav.shape[-1]
  23. mel_length = wav_length // HOP_LENGTH + 1
  24. wav = pad_or_trim(wav)
  25. wav = torch.from_numpy(wav).float()
  26. input_features = log_mel_spectrogram(wav)
  27. mel_mask = torch.zeros(input_features.shape[1], dtype=torch.float)
  28. mel_mask[:mel_length] = 1
  29. input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
  30. input_ids = [int(x) for x in input_ids.split(",")]
  31. while input_ids[-1] in [
  32. self.processor.tokenizer.pad_token_id,
  33. self.processor.tokenizer.eos_token_id,
  34. ]:
  35. input_ids.pop()
  36. input_ids.append(self.processor.tokenizer.eos_token_id)
  37. input_ids = torch.tensor(input_ids, dtype=torch.long)
  38. return {
  39. "input_values": wav,
  40. "input_features": input_features,
  41. "input_ids": input_ids,
  42. "mel_mask": mel_mask,
  43. }
  44. @dataclass
  45. class WhisperVQCollator:
  46. processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
  47. def __call__(self, batch):
  48. # -> {"input_values": ..., "input_features": ..., "input_ids": ..., "decoder_attention_mask": ...}
  49. max_values_length = max([x["input_values"].shape[-1] for x in batch])
  50. max_ids_length = max([x["input_ids"].shape[-1] for x in batch])
  51. input_values = []
  52. decoder_attention_mask = []
  53. decoder_input_ids = []
  54. input_features = torch.stack([x["input_features"] for x in batch])
  55. encoder_attention_mask = torch.stack([x["mel_mask"] for x in batch])
  56. for data in batch:
  57. values_length = data["input_values"].shape[-1]
  58. x = torch.nn.functional.pad(
  59. data["input_values"], (0, max_values_length - values_length)
  60. )
  61. input_values.append(x)
  62. ids_length = data["input_ids"].shape[-1]
  63. ids = torch.nn.functional.pad(
  64. data["input_ids"],
  65. (0, max_ids_length - ids_length),
  66. value=self.processor.tokenizer.pad_token_id,
  67. )
  68. decoder_input_ids.append(ids)
  69. x = torch.zeros(max_ids_length, dtype=torch.float)
  70. x[:ids_length] = 1
  71. decoder_attention_mask.append(x)
  72. decoder_input_ids = torch.stack(decoder_input_ids)
  73. decoder_attention_mask = torch.stack(decoder_attention_mask)
  74. labels = decoder_input_ids.clone()
  75. labels[decoder_attention_mask == 0] = -100
  76. labels[:, :4] = -100 # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
  77. return {
  78. "input_values": torch.stack(input_values),
  79. "input_features": input_features,
  80. "encoder_attention_mask": encoder_attention_mask,
  81. "decoder_input_ids": decoder_input_ids[:, :-1],
  82. "decoder_attention_mask": decoder_attention_mask[:, :-1],
  83. "labels": labels[:, 1:],
  84. }
if __name__ == "__main__":
    # Smoke test: load a trained WhisperVQ checkpoint, run Whisper generation
    # on a batch, and compare decoded transcripts against the labels.
    import soundfile as sf
    from torch.utils.data import DataLoader
    from transformers import GenerationConfig

    from fish_speech.models.whisper_vq import WhisperVQ
    from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration

    dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
    )

    # whisper = FlashWhisperForConditionalGeneration.from_pretrained(
    #     "openai/whisper-medium"
    # )
    # whisper.eval()

    # Use the Whisper model embedded inside WhisperVQ for generation.
    our_whisper = WhisperVQ()
    whisper = our_whisper.whisper
    our_whisper.eval()

    # Load trained weights from a local checkpoint (CPU to avoid GPU needs).
    state_dict = torch.load(
        "results/whisper-vq/checkpoints/step_10000.ckpt", map_location="cpu"
    )["model"]
    our_whisper.load_state_dict(state_dict, strict=True)
    # whisper.cuda()

    for batch in dataloader:
        # batch = {k: v.cuda() for k, v in batch.items()}
        print({k: v.shape for k, v in batch.items()})

        # Baseline: plain Whisper generation from the mel features.
        outputs = whisper.generate(
            inputs=batch["input_features"],
            max_length=448,
            do_sample=False,
        )
        print(outputs, batch["decoder_input_ids"])
        transcriptions = dataset.processor.batch_decode(
            outputs, skip_special_tokens=True
        )
        # Print model transcripts next to the ground-truth labels.
        print(
            transcriptions,
            dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
        )
        # Dump the first (padded) waveform for manual listening.
        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)

        # Calculate loss
        # encoder_outputs = whisper.model.encoder(
        #     batch["input_features"],
        # )
        # Encode through the VQ bottleneck, then decode back to encoder states.
        encoder_outputs = our_whisper.decode(
            our_whisper.encode(
                batch["input_features"],
            )[0]
        )
        # NOTE(review): encoder_outputs is passed as the first positional
        # argument of generate() (normally `inputs`) — presumably relying on
        # the model routing it as encoder output; confirm against the
        # generate() signature in use.
        decoder_outputs = whisper.generate(
            # decoder_input_ids=batch["decoder_input_ids"],
            # decoder_attention_mask=batch["decoder_attention_mask"],
            # labels=batch["labels"],
            # generation_config=GenerationConfig(
            #     encoder_outputs=(encoder_outputs,)
            # ),
            encoder_outputs,
            max_length=448,
            do_sample=False,
            # forced_decoder_ids=batch["decoder_input_ids"][:, :4]
            forced_decoder_ids=dataset.processor.get_decoder_prompt_ids(
                language="english", task="transcribe"
            ),
        )
        print("Our transcript:", dataset.processor.batch_decode(decoder_outputs))
        break  # one batch is enough for the smoke test