data_utils.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import random
  2. import os
  3. import re
  4. import numpy as np
  5. import torch
  6. import torch.utils.data
  7. import librosa
  8. import layers
  9. from utils import load_wav_to_torch, load_filepaths_and_text
  10. from text import text_to_sequence, cmudict
  11. from yin import compute_yin
  12. class TextMelLoader(torch.utils.data.Dataset):
  13. """
  14. 1) loads audio, text and speaker ids
  15. 2) normalizes text and converts them to sequences of one-hot vectors
  16. 3) computes mel-spectrograms and f0s from audio files.
  17. """
  18. def __init__(self, audiopaths_and_text, hparams, speaker_ids=None):
  19. self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
  20. self.text_cleaners = hparams.text_cleaners
  21. self.max_wav_value = hparams.max_wav_value
  22. self.sampling_rate = hparams.sampling_rate
  23. self.stft = layers.TacotronSTFT(
  24. hparams.filter_length, hparams.hop_length, hparams.win_length,
  25. hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
  26. hparams.mel_fmax)
  27. self.sampling_rate = hparams.sampling_rate
  28. self.filter_length = hparams.filter_length
  29. self.hop_length = hparams.hop_length
  30. self.f0_min = hparams.f0_min
  31. self.f0_max = hparams.f0_max
  32. self.harm_thresh = hparams.harm_thresh
  33. self.p_arpabet = hparams.p_arpabet
  34. self.cmudict = None
  35. if hparams.cmudict_path is not None:
  36. self.cmudict = cmudict.CMUDict(hparams.cmudict_path)
  37. self.speaker_ids = speaker_ids
  38. if speaker_ids is None:
  39. self.speaker_ids = self.create_speaker_lookup_table(
  40. self.audiopaths_and_text)
  41. random.seed(1234)
  42. random.shuffle(self.audiopaths_and_text)
  43. def create_speaker_lookup_table(self, audiopaths_and_text):
  44. speaker_ids = np.sort(np.unique([x[2] for x in audiopaths_and_text]))
  45. d = {int(speaker_ids[i]): i for i in range(len(speaker_ids))}
  46. return d
  47. # sampling_rate = 22050
  48. def get_f0(self, audio, sampling_rate=16000, frame_length=1024,
  49. hop_length=256, f0_min=100, f0_max=300, harm_thresh=0.1):
  50. f0, harmonic_rates, argmins, times = compute_yin(
  51. audio, sampling_rate, frame_length, hop_length, f0_min, f0_max,
  52. harm_thresh)
  53. pad = int((frame_length / hop_length) / 2)
  54. f0 = [0.0] * pad + f0 + [0.0] * pad
  55. f0 = np.array(f0, dtype=np.float32)
  56. return f0
  57. def get_data(self, audiopath_and_text):
  58. audiopath, text, speaker = audiopath_and_text
  59. text = self.get_text(text)
  60. mel, f0 = self.get_mel_and_f0(audiopath)
  61. speaker_id = self.get_speaker_id(speaker)
  62. return (text, mel, speaker_id, f0)
  63. def get_speaker_id(self, speaker_id):
  64. return torch.IntTensor([self.speaker_ids[int(speaker_id)]])
  65. def get_mel_and_f0(self, filepath):
  66. audio, sampling_rate = load_wav_to_torch(filepath)
  67. if sampling_rate != self.stft.sampling_rate:
  68. raise ValueError("{} SR doesn't match target {} SR".format(
  69. sampling_rate, self.stft.sampling_rate))
  70. audio_norm = audio / self.max_wav_value
  71. audio_norm = audio_norm.unsqueeze(0)
  72. melspec = self.stft.mel_spectrogram(audio_norm)
  73. melspec = torch.squeeze(melspec, 0)
  74. f0 = self.get_f0(audio.cpu().numpy(), self.sampling_rate,
  75. self.filter_length, self.hop_length, self.f0_min,
  76. self.f0_max, self.harm_thresh)
  77. f0 = torch.from_numpy(f0)[None]
  78. f0 = f0[:, :melspec.size(1)]
  79. return melspec, f0
  80. def get_text(self, text):
  81. text_norm = torch.IntTensor(
  82. text_to_sequence(text, self.text_cleaners, self.cmudict, self.p_arpabet))
  83. return text_norm
  84. def __getitem__(self, index):
  85. return self.get_data(self.audiopaths_and_text[index])
  86. def __len__(self):
  87. return len(self.audiopaths_and_text)
  88. class TextMelCollate():
  89. """ Zero-pads model inputs and targets based on number of frames per setep
  90. """
  91. def __init__(self, n_frames_per_step):
  92. self.n_frames_per_step = n_frames_per_step
  93. def __call__(self, batch):
  94. """Collate's training batch from normalized text and mel-spectrogram
  95. PARAMS
  96. ------
  97. batch: [text_normalized, mel_normalized]
  98. """
  99. # Right zero-pad all one-hot text sequences to max input length
  100. input_lengths, ids_sorted_decreasing = torch.sort(
  101. torch.LongTensor([len(x[0]) for x in batch]),
  102. dim=0, descending=True)
  103. max_input_len = input_lengths[0]
  104. text_padded = torch.LongTensor(len(batch), max_input_len)
  105. text_padded.zero_()
  106. for i in range(len(ids_sorted_decreasing)):
  107. text = batch[ids_sorted_decreasing[i]][0]
  108. text_padded[i, :text.size(0)] = text
  109. # Right zero-pad mel-spec
  110. num_mels = batch[0][1].size(0)
  111. max_target_len = max([x[1].size(1) for x in batch])
  112. if max_target_len % self.n_frames_per_step != 0:
  113. max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
  114. assert max_target_len % self.n_frames_per_step == 0
  115. # include mel padded, gate padded and speaker ids
  116. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  117. mel_padded.zero_()
  118. gate_padded = torch.FloatTensor(len(batch), max_target_len)
  119. gate_padded.zero_()
  120. output_lengths = torch.LongTensor(len(batch))
  121. speaker_ids = torch.LongTensor(len(batch))
  122. f0_padded = torch.FloatTensor(len(batch), 1, max_target_len)
  123. f0_padded.zero_()
  124. for i in range(len(ids_sorted_decreasing)):
  125. mel = batch[ids_sorted_decreasing[i]][1]
  126. mel_padded[i, :, :mel.size(1)] = mel
  127. gate_padded[i, mel.size(1)-1:] = 1
  128. output_lengths[i] = mel.size(1)
  129. speaker_ids[i] = batch[ids_sorted_decreasing[i]][2]
  130. f0 = batch[ids_sorted_decreasing[i]][3]
  131. f0_padded[i, :, :f0.size(1)] = f0
  132. model_inputs = (text_padded, input_lengths, mel_padded, gate_padded,
  133. output_lengths, speaker_ids, f0_padded)
  134. return model_inputs