viz.py 767 B

1234567891011121314151617181920212223242526272829
  1. import matplotlib
  2. from matplotlib import pyplot as plt
  3. from torch import Tensor
  4. matplotlib.use("Agg")
  5. def plot_mel(data, titles=None):
  6. fig, axes = plt.subplots(len(data), 1, squeeze=False)
  7. if titles is None:
  8. titles = [None for i in range(len(data))]
  9. plt.tight_layout()
  10. for i in range(len(data)):
  11. mel = data[i]
  12. if isinstance(mel, Tensor):
  13. mel = mel.detach().cpu().numpy()
  14. axes[i][0].imshow(mel, origin="lower")
  15. axes[i][0].set_aspect(2.5, adjustable="box")
  16. axes[i][0].set_ylim(0, mel.shape[0])
  17. axes[i][0].set_title(titles[i], fontsize="medium")
  18. axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
  19. axes[i][0].set_anchor("W")
  20. return fig