hubert_vq.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from pathlib import Path
  2. import librosa
  3. import torch
  4. from torch.utils.data import Dataset
  class HubertVQDataset(Dataset):
      """Dataset of waveforms whose audio paths are listed in a text file.

      The file list holds one audio path per line. Each item is loaded with
      ``librosa.load(..., mono=True)`` at ``sample_rate`` and returned as a
      float32 ``torch.Tensor``.

      Args:
          filelist: Path to the text file listing one audio path per line.
          sample_rate: Sampling rate passed to ``librosa.load`` as ``sr``.
              Defaults to 16000, the value that used to be hard-coded.
      """

      def __init__(self, filelist: str, sample_rate: int = 16000):
          super().__init__()
          self.sample_rate = sample_rate
          lines = Path(filelist).read_text(encoding="utf-8").splitlines()
          # Skip blank lines and surrounding whitespace: a trailing newline or
          # an empty row would otherwise become an empty path that only fails
          # later, inside librosa.load.
          self.files = [line.strip() for line in lines if line.strip()]

      def __len__(self) -> int:
          """Return the number of audio paths in the file list."""
          return len(self.files)

      def __getitem__(self, idx: int) -> torch.Tensor:
          """Load the ``idx``-th audio file and return it as a float tensor."""
          wav, _ = librosa.load(self.files[idx], sr=self.sample_rate, mono=True)
          return torch.from_numpy(wav).float()
  class HubertVQCollator:
      """Collate variable-length 1-D waveforms into one right-padded batch."""

      @staticmethod
      def __call__(batch):
          """Pad every waveform to the longest one and build a padding mask.

          Args:
              batch: Non-empty sequence of 1-D tensors (one waveform each).

          Returns:
              A dict with two stacked tensors of shape
              ``(len(batch), max_length)``:
              ``"input_values"`` holds the zero-padded waveforms and
              ``"attention_mask"`` holds 1 on real samples and 0 on padding,
              in the same dtype as the waveforms.

          Raises:
              ValueError: If ``batch`` is empty.
          """
          if len(batch) == 0:
              # max() would raise ValueError anyway; say what actually went wrong.
              raise ValueError("HubertVQCollator received an empty batch")

          max_length = max(len(x) for x in batch)
          input_values = []
          attention_mask = []
          for x in batch:
              x_length = len(x)
              # Right-pad the last (only) dimension with zeros.
              padded = torch.nn.functional.pad(x, (0, max_length - x_length))
              mask = torch.ones_like(padded)
              mask[x_length:] = 0
              input_values.append(padded)
              attention_mask.append(mask)

          return {
              "input_values": torch.stack(input_values),
              "attention_mask": torch.stack(attention_mask),
          }
  if __name__ == "__main__":
      # Smoke test: run one shuffled batch through a pretrained HuBERT CTC
      # model, print the transcription of the first item and dump that item's
      # (padded) waveform to disk for listening.
      import soundfile as sf
      from torch.utils.data import DataLoader
      from transformers import HubertForCTC, Wav2Vec2Processor

      model_name = "facebook/hubert-large-ls960-ft"

      dataset = HubertVQDataset("libritts-r.filelist")
      dataloader = DataLoader(
          dataset, batch_size=16, shuffle=True, collate_fn=HubertVQCollator()
      )

      hubert = HubertForCTC.from_pretrained(model_name)
      processor = Wav2Vec2Processor.from_pretrained(model_name)
      hubert.eval()

      # NOTE(review): the waveforms bypass the processor's feature extractor;
      # confirm whether this checkpoint expects normalized input.
      for batch in dataloader:
          print(batch)
          # Inference only: skip autograd bookkeeping for the forward pass.
          with torch.no_grad():
              logits = hubert(**batch).logits
          predicted_ids = torch.argmax(logits, dim=-1)
          transcription = processor.decode(predicted_ids[0])
          print(transcription)
          sf.write("test.wav", batch["input_values"][0].numpy(), 16000)
          break  # one batch is enough for the smoke test