| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
- # LICENSE is in incl_licenses directory.
- import glob
- import os
- import matplotlib
- import torch
- from torch.nn.utils import weight_norm
- matplotlib.use("Agg")
- import matplotlib.pylab as plt
- from scipy.io.wavfile import write
def plot_spectrogram(spectrogram):
    """Render a 2-D array as a heat-map figure with a colour bar.

    Args:
        spectrogram: 2-D array-like passed straight to ``Axes.imshow``.

    Returns:
        The matplotlib ``Figure``. It has already been drawn and closed in
        pyplot, but the returned object is still usable.
    """
    display_opts = dict(aspect="auto", origin="lower", interpolation="none")
    figure, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(spectrogram, **display_opts)
    plt.colorbar(image, ax=axis)
    # Render now, while the figure is still the active pyplot figure.
    figure.canvas.draw()
    # Close the current pyplot figure (the one created above); presumably so
    # repeated calls don't accumulate open figures.
    plt.close()
    return figure
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    """Render a 2-D array as a heat map with a fixed colour range.

    The colour scale is pinned to ``[1e-6, clip_max]`` instead of being derived
    from the data, so figures from different calls share one scale.

    Args:
        spectrogram: 2-D array-like passed straight to ``Axes.imshow``.
        clip_max: Upper bound of the colour scale (``vmax``).

    Returns:
        The drawn matplotlib ``Figure``, already closed in pyplot.
    """
    display_opts = dict(
        aspect="auto",
        origin="lower",
        interpolation="none",
        vmin=1e-6,
        vmax=clip_max,
    )
    figure, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(spectrogram, **display_opts)
    plt.colorbar(image, ax=axis)
    # Render now, while the figure is still the active pyplot figure.
    figure.canvas.draw()
    # Close the current pyplot figure (the one created above).
    plt.close()
    return figure
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise a convolution module's weights from a normal distribution.

    Intended for use with ``Module.apply``. A module is treated as a
    convolution when its class name contains ``"Conv"``; any other module is
    left untouched.

    Args:
        m: Module to (possibly) initialise.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    if "Conv" not in m.__class__.__name__:
        return
    # Write through ``.data`` so the in-place fill is not recorded by autograd.
    m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
    """Wrap a convolution module with weight normalisation.

    Intended for use with ``Module.apply``. Only modules whose class name
    contains ``"Conv"`` are modified; all others are ignored.

    Args:
        m: Module to (possibly) wrap with ``weight_norm``.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)
def get_padding(kernel_size, dilation=1):
    """Return the symmetric padding for a dilated kernel.

    A kernel of size ``k`` with dilation ``d`` spans ``d * (k - 1)`` extra
    samples beyond one. This function returns half of that span, truncated to
    an int. For stride 1 and an even span, that padding keeps the output
    length equal to the input length.

    Args:
        kernel_size: Kernel size ``k``.
        dilation: Dilation factor ``d``.

    Returns:
        ``int(d * (k - 1) / 2)``.
    """
    extra_span = dilation * (kernel_size - 1)
    return int(extra_span / 2)
def load_checkpoint(filepath, device):
    """Load a checkpoint file written with ``torch.save``.

    Args:
        filepath: Path to the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The deserialized object stored in the file.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop this validation.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint file not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE(review): torch.load unpickles the file (unless weights_only is in
    # effect for the installed torch version); only load trusted checkpoints.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with ``torch.save``.

    Progress messages are printed before and after the write.

    Args:
        filepath: Destination path.
        obj: Object to serialise (typically a dict of state dicts).
    """
    announcement = "Saving checkpoint to {}".format(filepath)
    print(announcement)
    torch.save(obj, filepath)
    print("Complete.")
def scan_checkpoint(cp_dir, prefix):
    """Return the path of the latest checkpoint in ``cp_dir`` for ``prefix``.

    Only files named ``prefix`` followed by exactly eight characters are
    considered. The suffix is presumably a zero-padded step count, so the
    lexicographically greatest match is treated as the latest.

    Args:
        cp_dir: Directory to search.
        prefix: Literal file-name prefix (e.g. ``"g_"``).

    Returns:
        Path of the lexicographically greatest match, or ``None`` if no file
        matches.
    """
    # Escape the literal parts so characters such as ``[``, ``*`` or ``?`` in
    # the directory or prefix are not treated as glob wildcards. Only the
    # eight-character suffix is a pattern.
    pattern = os.path.join(glob.escape(cp_dir), glob.escape(prefix) + "????????")
    return max(glob.glob(pattern), default=None)
|