|
|
@@ -1,19 +1,21 @@
|
|
|
+from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
|
|
|
|
import librosa
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset
|
|
|
from transformers import WhisperProcessor
|
|
|
-from dataclasses import dataclass
|
|
|
-from whisper.audio import log_mel_spectrogram, load_audio, pad_or_trim
|
|
|
+from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
|
|
+
|
|
|
|
|
|
class WhisperVQDataset(Dataset):
|
|
|
- def __init__(self, filelist: str, model_name_or_path: str = "openai/whisper-medium"):
|
|
|
+ def __init__(
|
|
|
+ self, filelist: str, model_name_or_path: str = "openai/whisper-medium"
|
|
|
+ ):
|
|
|
super().__init__()
|
|
|
|
|
|
self.files = [
|
|
|
- Path(line.strip())
|
|
|
- for line in Path(filelist).read_text().splitlines()
|
|
|
+ Path(line.strip()) for line in Path(filelist).read_text().splitlines()
|
|
|
]
|
|
|
self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
|
|
|
|
|
|
@@ -30,7 +32,10 @@ class WhisperVQDataset(Dataset):
|
|
|
input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
|
|
|
input_ids = [int(x) for x in input_ids.split(",")]
|
|
|
|
|
|
- while input_ids[-1] in [self.processor.tokenizer.pad_token_id, self.processor.tokenizer.eos_token_id]:
|
|
|
+ while input_ids[-1] in [
|
|
|
+ self.processor.tokenizer.pad_token_id,
|
|
|
+ self.processor.tokenizer.eos_token_id,
|
|
|
+ ]:
|
|
|
input_ids.pop()
|
|
|
|
|
|
input_ids.append(self.processor.tokenizer.eos_token_id)
|
|
|
@@ -59,11 +64,17 @@ class WhisperVQCollator:
|
|
|
|
|
|
for data in batch:
|
|
|
values_length = data["input_values"].shape[-1]
|
|
|
- x = torch.nn.functional.pad(data["input_values"], (0, max_values_length - values_length))
|
|
|
+ x = torch.nn.functional.pad(
|
|
|
+ data["input_values"], (0, max_values_length - values_length)
|
|
|
+ )
|
|
|
input_values.append(x)
|
|
|
|
|
|
ids_length = data["input_ids"].shape[-1]
|
|
|
- ids = torch.nn.functional.pad(data["input_ids"], (0, max_ids_length - ids_length), value=self.processor.tokenizer.pad_token_id)
|
|
|
+ ids = torch.nn.functional.pad(
|
|
|
+ data["input_ids"],
|
|
|
+ (0, max_ids_length - ids_length),
|
|
|
+ value=self.processor.tokenizer.pad_token_id,
|
|
|
+ )
|
|
|
decoder_input_ids.append(ids)
|
|
|
|
|
|
x = torch.zeros(max_ids_length, dtype=torch.float)
|
|
|
@@ -74,26 +85,30 @@ class WhisperVQCollator:
|
|
|
decoder_attention_mask = torch.stack(decoder_attention_mask)
|
|
|
labels = decoder_input_ids.clone()
|
|
|
labels[decoder_attention_mask == 0] = -100
|
|
|
+ labels[:, :4] = -100 # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
|
|
|
|
|
|
return {
|
|
|
"input_values": torch.stack(input_values),
|
|
|
"input_features": input_features,
|
|
|
"decoder_input_ids": decoder_input_ids[:, :-1],
|
|
|
"decoder_attention_mask": decoder_attention_mask[:, :-1],
|
|
|
- "labels": labels[:, 1:]
|
|
|
+ "labels": labels[:, 1:],
|
|
|
}
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
import soundfile as sf
|
|
|
from torch.utils.data import DataLoader
|
|
|
+
|
|
|
from speech_lm.models.flash_whisper import FlashWhisperForConditionalGeneration
|
|
|
|
|
|
dataset = WhisperVQDataset("test.filelist")
|
|
|
dataloader = DataLoader(
|
|
|
dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
|
|
|
)
|
|
|
- whisper = FlashWhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
|
|
|
+ whisper = FlashWhisperForConditionalGeneration.from_pretrained(
|
|
|
+ "openai/whisper-medium"
|
|
|
+ )
|
|
|
whisper.eval()
|
|
|
# whisper.cuda()
|
|
|
|
|
|
@@ -108,9 +123,14 @@ if __name__ == "__main__":
|
|
|
)
|
|
|
|
|
|
print(outputs, batch["decoder_input_ids"])
|
|
|
- transcriptions = dataset.processor.batch_decode(outputs, skip_special_tokens=True)
|
|
|
+ transcriptions = dataset.processor.batch_decode(
|
|
|
+ outputs, skip_special_tokens=True
|
|
|
+ )
|
|
|
|
|
|
- print(transcriptions, dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True))
|
|
|
+ print(
|
|
|
+ transcriptions,
|
|
|
+ dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
|
|
|
+ )
|
|
|
sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
|
|
|
|
|
|
# Calculate loss
|