#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# get_ipython().run_line_magic('matplotlib', 'inline')
import matplotlib
import matplotlib.pyplot as plt
import IPython.display as ipd

import sys
sys.path.append('waveglow/')
from itertools import cycle
import numpy as np
import scipy as sp
from scipy.io.wavfile import write
import pandas as pd
import librosa
import torch

from hparams import create_hparams
from model import Tacotron2, load_model
from waveglow.denoiser import Denoiser
from layers import TacotronSTFT
from data_utils import TextMelLoader, TextMelCollate
from text import cmudict, text_to_sequence
from mellotron_utils import get_data_from_musicxml


# In[2]:


def panner(signal, angle):
    # Constant-power pan: angle in degrees, negative = left, positive = right.
    angle = np.radians(angle)
    left = np.sqrt(2)/2.0 * (np.cos(angle) - np.sin(angle)) * signal
    right = np.sqrt(2)/2.0 * (np.cos(angle) + np.sin(angle)) * signal
    return np.dstack((left, right))[0]


# In[3]:


def plot_mel_f0_alignment(mel_source, mel_outputs_postnet, f0s, alignments,
                          figsize=(16, 16)):
    fig, axes = plt.subplots(4, 1, figsize=figsize)
    axes = axes.flatten()
    axes[0].imshow(mel_source, aspect='auto', origin='lower', interpolation='none')
    axes[1].imshow(mel_outputs_postnet, aspect='auto', origin='lower', interpolation='none')
    axes[2].scatter(range(len(f0s)), f0s, alpha=0.5, color='red', marker='.', s=1)
    axes[2].set_xlim(0, len(f0s))
    axes[3].imshow(alignments, aspect='auto', origin='lower', interpolation='none')
    axes[0].set_title("Source Mel")
    axes[1].set_title("Predicted Mel")
    axes[2].set_title("Source pitch contour")
    axes[3].set_title("Source rhythm")
    plt.tight_layout()


# In[4]:


def load_mel(path):
    # Load audio at the target sampling rate and compute its mel spectrogram.
    audio, sampling_rate = librosa.core.load(path, sr=hparams.sampling_rate)
    audio = torch.from_numpy(audio)
    if sampling_rate != hparams.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, stft.sampling_rate))
    audio_norm = audio.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = stft.mel_spectrogram(audio_norm)
    melspec = melspec.to(device)
    return melspec


# In[5]:


hparams = create_hparams()


# In[6]:


stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length,
                    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
                    hparams.mel_fmax)


# ## Load Models

# In[7]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# In[8]:


checkpoint_path = "models/mellotron_libritts.pt"
# mellotron = load_model(hparams).cuda().eval()
mellotron = load_model(hparams).to(device).eval()
mellotron.load_state_dict(
    torch.load(checkpoint_path, map_location=torch.device('cpu'))['state_dict'])


# In[9]:


# WaveGlow vocoder and its denoiser (removes vocoder bias from synthesized audio).
waveglow_path = 'models/waveglow_256channels_universal_v4.pt'
waveglow = torch.load(waveglow_path, map_location=torch.device('cpu'))['model'].to(device).eval()
denoiser = Denoiser(waveglow).to(device).eval()


# ## Setup dataloaders

# In[10]:


arpabet_dict = cmudict.CMUDict('data/cmu_dictionary')
audio_paths = 'data/examples_filelist.txt'
dataloader = TextMelLoader(audio_paths, hparams)
datacollate = TextMelCollate(1)


# ## Load data

# In[11]:


file_idx = 0
audio_path, text, sid = dataloader.audiopaths_and_text[file_idx]

# get audio path, encoded text, pitch contour and mel for gst
text_encoded = torch.LongTensor(
    text_to_sequence(text, hparams.text_cleaners, arpabet_dict))[None, :].to(device)
pitch_contour = dataloader[file_idx][3][None].to(device)
mel = load_mel(audio_path)
print(audio_path, text)

# load source data to obtain rhythm using tacotron 2 as a forced aligner
# x, y = mellotron.parse_batch(datacollate([dataloader[file_idx]]))
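# In[ ]:


# Optional sanity check (not part of the original notebook): visualize the
# source mel and pitch contour loaded above before running any inference.
# It only assumes the cells above have been executed; the indexing mirrors
# the plotting calls used later in this script.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 6))
ax1.imshow(mel.data.cpu().numpy()[0], aspect='auto', origin='lower', interpolation='none')
ax1.set_title("Source Mel")
f0_source = pitch_contour.data.cpu().numpy()[0, 0]
ax2.scatter(range(len(f0_source)), f0_source, alpha=0.5, color='red', marker='.', s=1)
ax2.set_title("Source pitch contour")
plt.tight_layout()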
# ## Define Speakers Set

# In[12]:


speaker_ids = TextMelLoader(
    "filelists/libritts_train_clean_100_audiopath_text_sid_shorterthan10s_atleast5min_train_filelist.txt",
    hparams).speaker_ids

speakers = pd.read_csv('filelists/libritts_speakerinfo.txt', engine='python',
                       header=None, comment=';', sep=r' *\| *',
                       names=['ID', 'SEX', 'SUBSET', 'MINUTES', 'NAME'])
speakers['MELLOTRON_ID'] = speakers['ID'].apply(
    lambda x: speaker_ids[x] if x in speaker_ids else -1)

# Cycle through shuffled female/male LibriTTS speakers with more than 20 minutes of data.
female_speakers = cycle(
    speakers.query("SEX == 'F' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())
male_speakers = cycle(
    speakers.query("SEX == 'M' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())


# # Style Transfer (Rhythm and Pitch Contour)

# In[13]:


# with torch.no_grad():
#     # get rhythm (alignment map) using tacotron 2
#     mel_outputs, mel_outputs_postnet, gate_outputs, rhythm = mellotron.forward(x)
#     rhythm = rhythm.permute(1, 0, 2)


# In[14]:


# speaker_id = next(female_speakers) if np.random.randint(2) else next(male_speakers)
# speaker_id = torch.LongTensor([speaker_id]).to(device)
#
# with torch.no_grad():
#     mel_outputs, mel_outputs_postnet, gate_outputs, _ = mellotron.inference_noattention(
#         (text_encoded, mel, speaker_id, pitch_contour, rhythm))
#
# plot_mel_f0_alignment(x[2].data.cpu().numpy()[0],
#                       mel_outputs_postnet.data.cpu().numpy()[0],
#                       pitch_contour.data.cpu().numpy()[0, 0],
#                       rhythm.data.cpu().numpy()[:, 0].T)


# In[15]:


# with torch.no_grad():
#     audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8), 0.01)[:, 0]
# ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate)


# # Singing Voice from Music Score

# In[16]:


data = get_data_from_musicxml('data/haendel_hallelujah_1.musicxml', 132, convert_stress=True)
# data = get_data_from_musicxml('data/Dream_a_little_dream_of_me.musicxml', 132, convert_stress=True)

# Stereo pan range (in degrees) for each voice part.
panning = {'Soprano': [-60, -30], 'Alto': [-40, -10], 'Tenor': [30, 60], 'Bass': [10, 40]}


# In[17]:


# Render each voice part with several speakers, pan each rendition,
# and mix everything into a single stereo buffer.
n_speakers_per_part = 4
frequency_scaling = 0.4
n_seconds = 90
audio_stereo = np.zeros((hparams.sampling_rate*n_seconds, 2), dtype=np.float32)
for i, (part, v) in enumerate(data.items()):
    rhythm = data[part]['rhythm'].to(device)
    pitch_contour = data[part]['pitch_contour'].to(device)
    text_encoded = data[part]['text_encoded'].to(device)
    for k in range(n_speakers_per_part):
        pan = np.random.randint(panning[part][0], panning[part][1])
        if any(x in part.lower() for x in ('soprano', 'alto', 'female')):
            speaker_id = torch.LongTensor([next(female_speakers)]).to(device)
        else:
            speaker_id = torch.LongTensor([next(male_speakers)]).to(device)
        print("{} MellotronID {} pan {}".format(part, speaker_id.item(), pan))
        with torch.no_grad():
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments_transfer = mellotron.inference_noattention(
                (text_encoded, mel, speaker_id, pitch_contour*frequency_scaling, rhythm))

            audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8), 0.01)[0, 0]
            audio = audio.cpu().numpy()
            audio = panner(audio, pan)
            audio_stereo[:audio.shape[0]] += audio
            write("{} {}.wav".format(part, speaker_id.item()), hparams.sampling_rate, audio)


# In[18]:


audio_stereo = audio_stereo / np.max(np.abs(audio_stereo))
write("audio_stereo.wav", hparams.sampling_rate, audio_stereo)
ipd.Audio([audio_stereo[:, 0], audio_stereo[:, 1]], rate=hparams.sampling_rate)
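# In[ ]:


# Optional (not part of the original notebook): the write() calls above store
# float32 WAV files. A minimal sketch, assuming the normalized stereo mix from
# the previous cell, that also saves a 16-bit PCM copy for players that do not
# handle float WAVs. The filename is an arbitrary choice for this example.
audio_stereo_int16 = (audio_stereo * 32767.0).astype(np.int16)
write("audio_stereo_int16.wav", hparams.sampling_rate, audio_stereo_int16)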
# In[19]:


# mellotron.inference_noattention(text_encoded)
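# In[ ]:


# A minimal single-voice sketch (not in the original notebook): re-render the
# last part processed in the singing loop above with a freshly drawn male
# speaker, mono and without panning. It reuses `text_encoded`, `pitch_contour`,
# and `rhythm` left over from that loop, plus the reference `mel` from the
# "Load data" cell, so it only runs after those cells have executed.
speaker_id = torch.LongTensor([next(male_speakers)]).to(device)
with torch.no_grad():
    _, mel_out_postnet, _, _ = mellotron.inference_noattention(
        (text_encoded, mel, speaker_id, pitch_contour*frequency_scaling, rhythm))
    audio_single = denoiser(waveglow.infer(mel_out_postnet, sigma=0.8), 0.01)[0, 0]
ipd.Audio(audio_single.cpu().numpy(), rate=hparams.sampling_rate)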