@@ -2,9 +2,10 @@ from dataclasses import dataclass
 from pathlib import Path
 
 import librosa
+import numpy as np
 import torch
+from lightning import LightningDataModule
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
-from transformers import WhisperProcessor
 
 
 class VQGANDataset(Dataset):
@@ -19,6 +20,7 @@ class VQGANDataset(Dataset):
         root = filelist.parent
 
         self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
+        self.sample_rate = sample_rate
 
     def __len__(self):
        return len(self.files)
@@ -26,129 +28,94 @@ class VQGANDataset(Dataset):
     def __getitem__(self, idx):
         file = self.files[idx]
 
+        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
+        features = np.load(file.with_suffix(".npy"))  # (T, 1024)
+
+        return {
+            "audio": torch.from_numpy(audio),
+            "features": torch.from_numpy(features),
+        }
+
 
 @dataclass
-class WhisperVQCollator:
-    processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+class VQGANCollator:
+    hop_length: int = 640
 
     def __call__(self, batch):
-        # -> {"input_values": ..., "input_features": ..., "input_ids": ..., "decoder_attention_mask": ...}
-        max_values_length = max([x["input_values"].shape[-1] for x in batch])
-        max_ids_length = max([x["input_ids"].shape[-1] for x in batch])
-
-        input_values = []
-        decoder_attention_mask = []
-        decoder_input_ids = []
-        input_features = torch.stack([x["input_features"] for x in batch])
-        encoder_attention_mask = torch.stack([x["mel_mask"] for x in batch])
-
-        for data in batch:
-            values_length = data["input_values"].shape[-1]
-            x = torch.nn.functional.pad(
-                data["input_values"], (0, max_values_length - values_length)
-            )
-            input_values.append(x)
+        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+        feature_lengths = torch.tensor([len(x["features"]) for x in batch])
 
-            ids_length = data["input_ids"].shape[-1]
-            ids = torch.nn.functional.pad(
-                data["input_ids"],
-                (0, max_ids_length - ids_length),
-                value=self.processor.tokenizer.pad_token_id,
-            )
-            decoder_input_ids.append(ids)
+        audio_maxlen = audio_lengths.max()
+        feature_maxlen = feature_lengths.max()
 
-            x = torch.zeros(max_ids_length, dtype=torch.float)
-            x[:ids_length] = 1
-            decoder_attention_mask.append(x)
+        if audio_maxlen % self.hop_length != 0:
+            audio_maxlen += self.hop_length - (audio_maxlen % self.hop_length)
 
-        decoder_input_ids = torch.stack(decoder_input_ids)
-        decoder_attention_mask = torch.stack(decoder_attention_mask)
-        labels = decoder_input_ids.clone()
-        labels[decoder_attention_mask == 0] = -100
-        labels[:, :4] = -100  # BOS, LANG, TRANSCRIBE, NOTIMESTAMPS
+        audios, features = [], []
+        for x in batch:
+            audios.append(
+                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+            )
+            features.append(
+                torch.nn.functional.pad(
+                    x["features"], (0, 0, 0, feature_maxlen - len(x["features"]))
+                )
+            )
 
         return {
-            "input_values": torch.stack(input_values),
-            "input_features": input_features,
-            "encoder_attention_mask": encoder_attention_mask,
-            "decoder_input_ids": decoder_input_ids[:, :-1],
-            "decoder_attention_mask": decoder_attention_mask[:, :-1],
-            "labels": labels[:, 1:],
+            "audios": torch.stack(audios),
+            "features": torch.stack(features),
+            "audio_lengths": audio_lengths,
+            "feature_lengths": feature_lengths,
         }
 
 
-if __name__ == "__main__":
-    import soundfile as sf
-    from torch.utils.data import DataLoader
-    from transformers import GenerationConfig
-
-    from fish_speech.models.whisper_vq import WhisperVQ
-    from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration
-
-    dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
-    dataloader = DataLoader(
-        dataset, batch_size=4, shuffle=True, collate_fn=WhisperVQCollator()
-    )
-    # whisper = FlashWhisperForConditionalGeneration.from_pretrained(
-    #     "openai/whisper-medium"
-    # )
-    # whisper.eval()
-    our_whisper = WhisperVQ()
-    whisper = our_whisper.whisper
-    our_whisper.eval()
-
-    state_dict = torch.load(
-        "results/whisper-vq/checkpoints/step_10000.ckpt", map_location="cpu"
-    )["model"]
-    our_whisper.load_state_dict(state_dict, strict=True)
-    # whisper.cuda()
-
-    for batch in dataloader:
-        # batch = {k: v.cuda() for k, v in batch.items()}
-        print({k: v.shape for k, v in batch.items()})
+class VQGANDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: VQGANDataset,
+        val_dataset: VQGANDataset,
+        batch_size: int = 32,
+        hop_length: int = 640,
+        num_workers: int = 4,
+    ):
+        super().__init__()
 
-        outputs = whisper.generate(
-            inputs=batch["input_features"],
-            max_length=448,
-            do_sample=False,
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.hop_length = hop_length
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VQGANCollator(self.hop_length),
+            num_workers=self.num_workers,
+            shuffle=True,
        )
 
-        print(outputs, batch["decoder_input_ids"])
-        transcriptions = dataset.processor.batch_decode(
-            outputs, skip_special_tokens=True
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VQGANCollator(self.hop_length),
+            num_workers=self.num_workers,
        )
 
-        print(
-            transcriptions,
-            dataset.processor.batch_decode(batch["labels"], skip_special_tokens=True),
-        )
-        sf.write("test.wav", batch["input_values"][0].cpu().numpy(), 16000)
-
-        # Calculate loss
-        # encoder_outputs = whisper.model.encoder(
-        #     batch["input_features"],
-        # )
-        encoder_outputs = our_whisper.decode(
-            our_whisper.encode(
-                batch["input_features"],
-            )[0]
-        )
 
-        decoder_outputs = whisper.generate(
-            # decoder_input_ids=batch["decoder_input_ids"],
-            # decoder_attention_mask=batch["decoder_attention_mask"],
-            # labels=batch["labels"],
-            # generation_config=GenerationConfig(
-            #     encoder_outputs=(encoder_outputs,)
-            # ),
-            encoder_outputs,
-            max_length=448,
-            do_sample=False,
-            # forced_decoder_ids=batch["decoder_input_ids"][:, :4]
-            forced_decoder_ids=dataset.processor.get_decoder_prompt_ids(
-                language="english", task="transcribe"
-            ),
-        )
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
 
-        print("Our transcript:", dataset.processor.batch_decode(decoder_outputs))
+    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
+    dataloader = DataLoader(
+        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
+    )
+
+    for batch in dataloader:
+        print(batch["audios"].shape)
+        print(batch["features"].shape)
+        print(batch["audio_lengths"])
+        print(batch["feature_lengths"])
         break
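
For reference, a minimal usage sketch of the new classes outside this diff. It assumes the module lives at fish_speech.datasets.vqgan; the validation filelist path is a placeholder and not defined by this change.

# Usage sketch (not part of the diff above): build the VQGANDataModule and pull
# one batch through VQGANCollator via the Lightning-style train_dataloader().
# The import path and the val filelist path are assumptions, not part of this change.
from fish_speech.datasets.vqgan import VQGANDataModule, VQGANDataset

datamodule = VQGANDataModule(
    train_dataset=VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt"),
    val_dataset=VQGANDataset("data/LibriTTS_R/vq_val_filelist.txt"),  # placeholder path
    batch_size=32,
    hop_length=640,  # should match the hop size the collator pads audio to
    num_workers=4,
)

batch = next(iter(datamodule.train_dataloader()))
# audios: (B, T_audio) padded to a multiple of hop_length; features: (B, T_feat, 1024)
print(batch["audios"].shape, batch["features"].shape)
print(batch["audio_lengths"], batch["feature_lengths"])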