utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import matplotlib
  2. import torch
  3. from matplotlib import pyplot as plt
  4. matplotlib.use("Agg")
  def convert_pad_shape(pad_shape):
      """Flatten a nested padding spec, visiting the outer entries last-to-first.

      For example ``[[a, b], [c, d]]`` becomes ``[c, d, a, b]``.

      Args:
          pad_shape: Sliceable sequence of iterables (e.g. list of pairs).

      Returns:
          A flat list of the inner items, outer order reversed.
      """
      # NOTE(review): presumably this yields the last-dimension-first layout
      # used by torch.nn.functional.pad -- confirm against callers.
      flattened = []
      for entry in pad_shape[::-1]:
          flattened.extend(entry)
      return flattened
  def sequence_mask(length, max_length=None):
      """Return a boolean mask that is True where the position index < length.

      Args:
          length: Tensor of sequence lengths (assumed 1-D, one entry per
              sequence -- TODO confirm with callers).
          max_length: Number of positions in the mask; defaults to
              ``length.max()``.

      Returns:
          Boolean tensor; for a 1-D ``length`` of size N the shape is
          ``(N, max_length)``.
      """
      limit = length.max() if max_length is None else max_length
      positions = torch.arange(limit, dtype=length.dtype, device=length.device)
      # Broadcast a (1, limit) row of positions against a (N, 1) column of lengths.
      return positions[None, :] < length[:, None]
  def init_weights(m, mean=0.0, std=0.01):
      """Re-draw ``m.weight`` from a normal distribution for "Conv" classes.

      Objects whose class name does not contain ``"Conv"`` are left untouched.

      Args:
          m: Object to initialise; needs a ``weight`` tensor attribute when its
              class name contains ``"Conv"``.
          mean: Mean of the normal distribution.
          std: Standard deviation of the normal distribution.
      """
      if "Conv" not in m.__class__.__name__:
          return
      # In-place draw on .data so the re-initialisation is not tracked by autograd.
      m.weight.data.normal_(mean, std)
  def get_padding(kernel_size, dilation=1):
      """Return ``int((kernel_size * dilation - dilation) / 2)``.

      Args:
          kernel_size: Kernel width.
          dilation: Dilation factor.

      Returns:
          The halved span as an int, truncated toward zero.
      """
      # NOTE(review): presumably the "same" padding for an odd-sized dilated
      # convolution -- confirm against callers.
      span = kernel_size * dilation - dilation
      return int(span / 2)
  def plot_mel(data, titles=None):
      """Draw every entry of ``data`` as an image in its own subplot row.

      Args:
          data: Sequence of 2-D arrays or ``torch.Tensor`` objects; tensors are
              detached and moved to CPU before plotting. (Presumably
              mel-spectrograms, judging by the name -- confirm.)
          titles: Optional per-row titles, indexed like ``data``; defaults to
              no titles.

      Returns:
          The matplotlib figure that holds the subplots.
      """
      row_count = len(data)
      fig, axes = plt.subplots(row_count, 1, squeeze=False)
      if titles is None:
          titles = [None] * row_count
      plt.tight_layout()

      for row, mel in enumerate(data):
          if isinstance(mel, torch.Tensor):
              mel = mel.detach().cpu().numpy()
          # squeeze=False keeps `axes` 2-D, so the single column is index 0.
          ax = axes[row][0]
          ax.imshow(mel, origin="lower")
          ax.set_aspect(2.5, adjustable="box")
          ax.set_ylim(0, mel.shape[0])
          ax.set_title(titles[row], fontsize="medium")
          ax.tick_params(labelsize="x-small", left=False, labelleft=False)
          ax.set_anchor("W")

      return fig
  def slice_segments(x, ids_str, segment_size=4):
      """Cut one window of ``segment_size`` steps from the last axis of each row.

      Args:
          x: Tensor indexed as ``x[batch, :, time]``.
          ids_str: Per-row start index along the last axis.
          segment_size: Window length.

      Returns:
          Tensor shaped like ``x[:, :, :segment_size]`` where row ``i`` holds
          ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``.
      """
      segments = torch.zeros_like(x[:, :, :segment_size])
      batch_size = x.size(0)
      for row in range(batch_size):
          start = ids_str[row]
          segments[row] = x[row, :, start:start + segment_size]
      return segments
  def rand_slice_segments(x, x_lengths=None, segment_size=4):
      """Cut a randomly positioned window of ``segment_size`` steps from each row.

      Args:
          x: 3-D tensor; its size is unpacked as ``(batch, channels, time)``.
          x_lengths: Usable length of every row, as a tensor or a plain number.
              Defaults to the full time dimension of ``x``.
          segment_size: Window length.

      Returns:
          Tuple ``(segments, ids_str)``: the sliced windows and the start index
          chosen for each row (dtype ``torch.long``).
      """
      b, _, t = x.size()
      if x_lengths is None:
          x_lengths = t
      # torch.clamp only accepts tensors; without this promotion the default
      # (x_lengths=None -> int) path raises TypeError.
      if not torch.is_tensor(x_lengths):
          x_lengths = torch.as_tensor(x_lengths, device=x.device)

      # Exclusive upper bound for the start index; clamped at 0 so rows shorter
      # than the window always start at 0.
      ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
      # Scale uniform [0, 1) samples by the bound and truncate to integers.
      ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
      ret = slice_segments(x, ids_str, segment_size)
      return ret, ids_str
  @torch.jit.script
  def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
      """Sum the two inputs, then gate a tanh half with a sigmoid half.

      The sum is split along dimension 1 at ``n_channels[0]``: the leading part
      goes through tanh, the remainder through sigmoid, and the two are
      multiplied elementwise.

      Args:
          input_a: First addend, indexed with three dimensions.
          input_b: Second addend, added to ``input_a``.
          n_channels: Tensor whose first element is the split index.

      Returns:
          ``tanh(s[:, :n]) * sigmoid(s[:, n:])`` where ``s = input_a + input_b``
          and ``n = n_channels[0]``.
      """
      split_at = n_channels[0]
      summed = input_a + input_b
      tanh_part = torch.tanh(summed[:, :split_at, :])
      sigmoid_part = torch.sigmoid(summed[:, split_at:, :])
      return tanh_part * sigmoid_part