utils.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
  2. # LICENSE is in incl_licenses directory.
  3. import glob
  4. import os
  5. import matplotlib
  6. import torch
  7. from torch.nn.utils import weight_norm
  8. matplotlib.use("Agg")
  9. import matplotlib.pylab as plt
  10. from scipy.io.wavfile import write
  11. def plot_spectrogram(spectrogram):
  12. fig, ax = plt.subplots(figsize=(10, 2))
  13. im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
  14. plt.colorbar(im, ax=ax)
  15. fig.canvas.draw()
  16. plt.close()
  17. return fig
  18. def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
  19. fig, ax = plt.subplots(figsize=(10, 2))
  20. im = ax.imshow(
  21. spectrogram,
  22. aspect="auto",
  23. origin="lower",
  24. interpolation="none",
  25. vmin=1e-6,
  26. vmax=clip_max,
  27. )
  28. plt.colorbar(im, ax=ax)
  29. fig.canvas.draw()
  30. plt.close()
  31. return fig
  32. def init_weights(m, mean=0.0, std=0.01):
  33. classname = m.__class__.__name__
  34. if classname.find("Conv") != -1:
  35. m.weight.data.normal_(mean, std)
  36. def apply_weight_norm(m):
  37. classname = m.__class__.__name__
  38. if classname.find("Conv") != -1:
  39. weight_norm(m)
  40. def get_padding(kernel_size, dilation=1):
  41. return int((kernel_size * dilation - dilation) / 2)
  42. def load_checkpoint(filepath, device):
  43. assert os.path.isfile(filepath)
  44. print("Loading '{}'".format(filepath))
  45. checkpoint_dict = torch.load(filepath, map_location=device)
  46. print("Complete.")
  47. return checkpoint_dict
  48. def save_checkpoint(filepath, obj):
  49. print("Saving checkpoint to {}".format(filepath))
  50. torch.save(obj, filepath)
  51. print("Complete.")
  52. def scan_checkpoint(cp_dir, prefix):
  53. pattern = os.path.join(cp_dir, prefix + "????????")
  54. cp_list = glob.glob(pattern)
  55. if len(cp_list) == 0:
  56. return None
  57. return sorted(cp_list)[-1]