
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pyplot as plt
import IPython.display as ipd

import sys
sys.path.append('waveglow/')

from itertools import cycle

import numpy as np
import scipy as sp
from scipy.io.wavfile import write
import pandas as pd
import librosa
import torch

from hparams import create_hparams
from model import Tacotron2, load_model
from waveglow.denoiser import Denoiser
from layers import TacotronSTFT
from data_utils import TextMelLoader, TextMelCollate
from text import cmudict, text_to_sequence
from mellotron_utils import get_data_from_musicxml
# In[2]:

def panner(signal, angle):
    """Constant-power pan a mono signal into stereo; angle in degrees (negative = left, positive = right)."""
    angle = np.radians(angle)
    left = np.sqrt(2) / 2.0 * (np.cos(angle) - np.sin(angle)) * signal
    right = np.sqrt(2) / 2.0 * (np.cos(angle) + np.sin(angle)) * signal
    return np.dstack((left, right))[0]
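# Usage sketch (not part of the original notebook): constant-power panning keeps
# left^2 + right^2 equal to the input power, so loudness stays roughly constant across
# the stereo field; e.g. -45 degrees lands fully in the left channel.
# tone = np.sin(2 * np.pi * 440 * np.arange(22050) / 22050).astype(np.float32)
# stereo = panner(tone, -45)   # stereo.shape == (22050, 2), right channel ~ 0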
# In[3]:


def plot_mel_f0_alignment(mel_source, mel_outputs_postnet, f0s, alignments, figsize=(16, 16)):
    """Plot the source mel, predicted mel, source pitch contour and attention (rhythm) map."""
    fig, axes = plt.subplots(4, 1, figsize=figsize)
    axes = axes.flatten()
    axes[0].imshow(mel_source, aspect='auto', origin='lower', interpolation='none')
    axes[1].imshow(mel_outputs_postnet, aspect='auto', origin='lower', interpolation='none')
    axes[2].scatter(range(len(f0s)), f0s, alpha=0.5, color='red', marker='.', s=1)
    axes[2].set_xlim(0, len(f0s))
    axes[3].imshow(alignments, aspect='auto', origin='lower', interpolation='none')
    axes[0].set_title("Source Mel")
    axes[1].set_title("Predicted Mel")
    axes[2].set_title("Source pitch contour")
    axes[3].set_title("Source rhythm")
    plt.tight_layout()


# In[4]:
def load_mel(path):
    """Load an audio file and return its mel spectrogram as a tensor on `device`."""
    audio, sampling_rate = librosa.core.load(path, sr=hparams.sampling_rate)
    audio = torch.from_numpy(audio)
    if sampling_rate != hparams.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, hparams.sampling_rate))
    audio_norm = audio.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = stft.mel_spectrogram(audio_norm)
    melspec = melspec.to(device)
    return melspec
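# Usage sketch: once `hparams`, `stft` and `device` are defined below, any wav can be
# turned into a (1, n_mel_channels, T) mel tensor; the path here is only a placeholder.
# reference_mel = load_mel('data/example.wav')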
# In[5]:


hparams = create_hparams()


# In[6]:


stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length,
                    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
                    hparams.mel_fmax)
# ## Load Models

# In[7]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# In[8]:


checkpoint_path = "models/mellotron_libritts.pt"
# mellotron = load_model(hparams).cuda().eval()
mellotron = load_model(hparams).to(device).eval()
mellotron.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))['state_dict'])


# In[9]:


waveglow_path = 'models/waveglow_256channels_universal_v4.pt'
waveglow = torch.load(waveglow_path, map_location=torch.device('cpu'))['model'].to(device).eval()
denoiser = Denoiser(waveglow).to(device).eval()
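# Vocoding sketch (mirrors the synthesis cells further down): a (1, n_mel_channels, T)
# mel on `device` can be turned into audio with WaveGlow and lightly denoised;
# `some_mel` is a placeholder name.
# with torch.no_grad():
#     audio = denoiser(waveglow.infer(some_mel, sigma=0.8), 0.01)[:, 0]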
# ## Setup dataloaders

# In[10]:


arpabet_dict = cmudict.CMUDict('data/cmu_dictionary')
audio_paths = 'data/examples_filelist.txt'
dataloader = TextMelLoader(audio_paths, hparams)
datacollate = TextMelCollate(1)


# ## Load data

# In[11]:


file_idx = 0
audio_path, text, sid = dataloader.audiopaths_and_text[file_idx]

# get audio path, encoded text, pitch contour and mel for gst
text_encoded = torch.LongTensor(text_to_sequence(text, hparams.text_cleaners, arpabet_dict))[None, :].to(device)
pitch_contour = dataloader[file_idx][3][None].to(device)
mel = load_mel(audio_path)
print(audio_path, text)

# load source data to obtain rhythm using tacotron 2 as a forced aligner
# x, y = mellotron.parse_batch(datacollate([dataloader[file_idx]]))


# In[12]:


# In[14]:
# ## Define Speakers Set

# In[15]:


speaker_ids = TextMelLoader("filelists/libritts_train_clean_100_audiopath_text_sid_shorterthan10s_atleast5min_train_filelist.txt", hparams).speaker_ids
speakers = pd.read_csv('filelists/libritts_speakerinfo.txt', engine='python', header=None, comment=';',
                       sep=r' *\| *', names=['ID', 'SEX', 'SUBSET', 'MINUTES', 'NAME'])
speakers['MELLOTRON_ID'] = speakers['ID'].apply(lambda x: speaker_ids[x] if x in speaker_ids else -1)

# endless, shuffled iterators over female/male speakers with enough data and a Mellotron ID
female_speakers = cycle(
    speakers.query("SEX == 'F' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())
male_speakers = cycle(
    speakers.query("SEX == 'M' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())
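# Sketch: the cycles above never run out; each next() call yields the Mellotron ID of
# another speaker of the requested sex, which is how the synthesis loop below picks voices.
# example_speaker_id = torch.LongTensor([next(female_speakers)]).to(device)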
# In[ ]:


# # Style Transfer (Rhythm and Pitch Contour)

# In[16]:
# with torch.no_grad():
#     # get rhythm (alignment map) using tacotron 2
#     mel_outputs, mel_outputs_postnet, gate_outputs, rhythm = mellotron.forward(x)
#     rhythm = rhythm.permute(1, 0, 2)
#
#
# # In[17]:
#
#
# speaker_id = next(female_speakers) if np.random.randint(2) else next(male_speakers)
# speaker_id = torch.LongTensor([speaker_id]).to(device)
#
# with torch.no_grad():
#     mel_outputs, mel_outputs_postnet, gate_outputs, _ = mellotron.inference_noattention(
#         (text_encoded, mel, speaker_id, pitch_contour, rhythm))
#
# plot_mel_f0_alignment(x[2].data.cpu().numpy()[0],
#                       mel_outputs_postnet.data.cpu().numpy()[0],
#                       pitch_contour.data.cpu().numpy()[0, 0],
#                       rhythm.data.cpu().numpy()[:, 0].T)
#
#
# # In[18]:
#
#
# with torch.no_grad():
#     audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8), 0.01)[:, 0]
# ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate)
# # Singing Voice from Music Score

# In[49]:


data = get_data_from_musicxml('data/haendel_hallelujah_1.musicxml', 132, convert_stress=True)
# data = get_data_from_musicxml('data/Dream_a_little_dream_of_me.musicxml', 132, convert_stress=True)
panning = {'Soprano': [-60, -30], 'Alto': [-40, -10], 'Tenor': [30, 60], 'Bass': [10, 40]}


# In[ ]:


torch.LongTensor([next(female_speakers)]).to(device)


# In[18]:
n_speakers_per_part = 4
frequency_scaling = 0.4
n_seconds = 90
audio_stereo = np.zeros((hparams.sampling_rate * n_seconds, 2), dtype=np.float32)

# synthesize each choir part with several speakers, pan each voice, and mix into a stereo buffer
for i, (part, v) in enumerate(data.items()):
    rhythm = data[part]['rhythm'].to(device)
    pitch_contour = data[part]['pitch_contour'].to(device)
    text_encoded = data[part]['text_encoded'].to(device)

    for k in range(n_speakers_per_part):
        pan = np.random.randint(panning[part][0], panning[part][1])
        if any(x in part.lower() for x in ('soprano', 'alto', 'female')):
            speaker_id = torch.LongTensor([next(female_speakers)]).to(device)
        else:
            speaker_id = torch.LongTensor([next(male_speakers)]).to(device)
        print("{} MellotronID {} pan {}".format(part, speaker_id.item(), pan))

        with torch.no_grad():
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments_transfer = mellotron.inference_noattention(
                (text_encoded, mel, speaker_id, pitch_contour * frequency_scaling, rhythm))

            audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8), 0.01)[0, 0]
            audio = audio.cpu().numpy()
            audio = panner(audio, pan)
            audio_stereo[:audio.shape[0]] += audio
            write("{} {}.wav".format(part, speaker_id.item()), hparams.sampling_rate, audio)
# In[19]:


audio_stereo = audio_stereo / np.max(np.abs(audio_stereo))
write("audio_stereo.wav", hparams.sampling_rate, audio_stereo)
ipd.Audio([audio_stereo[:, 0], audio_stereo[:, 1]], rate=hparams.sampling_rate)


# In[20]:


# mellotron.inference_noattention(text_encoded)


# In[21]: