| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import matplotlib
- import torch
- from matplotlib import pyplot as plt
- matplotlib.use("Agg")
def convert_pad_shape(pad_shape):
    """Reverse a per-dimension padding spec and flatten it into one list.

    ``pad_shape`` is a sequence of per-dimension entries (e.g. ``[left, right]``
    pairs), listed first dimension first. The result walks the dimensions from
    last to first and concatenates their entries. That is the
    last-dimension-first layout ``torch.nn.functional.pad`` expects
    (presumably the intended consumer; confirm at call sites).

    Example: ``[[0, 0], [1, 2], [3, 4]] -> [3, 4, 1, 2, 0, 0]``.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
def sequence_mask(length, max_length=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        length: Tensor of sequence lengths (for a 1-D tensor of size B, one
            length per batch item).
        max_length: Width of the mask. When omitted, ``length.max()`` is used.

    Returns:
        Boolean tensor where entry ``[b, t]`` is True iff ``t < length[b]``.
        For a 1-D ``length`` of size B the shape is ``(B, max_length)``.
        The position indices share ``length``'s dtype and device.
    """
    if max_length is None:
        max_length = length.max()
    steps = torch.arange(max_length, dtype=length.dtype, device=length.device)
    # Broadcast a row of positions against a column of lengths.
    return steps[None, :] < length[:, None]
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise a module's ``weight`` in place if it is convolution-like.

    A module counts as convolution-like when its class name contains the
    substring ``"Conv"``. Its ``weight`` tensor is then refilled with samples
    from a normal distribution with the given ``mean`` and standard deviation
    ``std``. Any other module is left untouched. The single-module signature
    makes this usable as a callback for ``nn.Module.apply``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    # Write through ``.data`` so the in-place fill is not recorded by autograd.
    m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int.

    This is half the extra reach of a dilated kernel, ``dilation * (k - 1)``.
    For a stride-1 convolution with an odd ``kernel_size``, padding both sides
    by this amount keeps the sequence length unchanged.
    """
    reach = kernel_size * dilation - dilation
    half = reach / 2
    return int(half)
def plot_mel(data, titles=None):
    """Draw each array in ``data`` as its own panel in a single-column figure.

    Args:
        data: Non-empty sequence of 2-D arrays or ``torch.Tensor`` objects.
            Tensors are detached and copied to CPU before plotting. Each
            panel is shown with ``imshow`` (origin at the bottom), and its
            y-limit spans ``shape[0]``.
        titles: Optional sequence of panel titles, indexed in step with
            ``data``. Defaults to untitled panels.

    Returns:
        The ``matplotlib.figure.Figure`` that was created. It stays registered
        with pyplot, so callers should ``plt.close(fig)`` when finished.
    """
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None] * len(data)

    for i, mel in enumerate(data):
        if isinstance(mel, torch.Tensor):
            mel = mel.detach().cpu().numpy()
        ax = axes[i][0]
        ax.imshow(mel, origin="lower")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[i], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    # tight_layout sizes the margins around the decorations present when it
    # runs. Call it only after titles and tick labels exist, and on this
    # figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of length ``segment_size`` out of each batch item.

    Args:
        x: Tensor indexed as ``x[i, :, start:end]``. The window is taken along
            dim 2 (presumably the time axis of a (batch, channels, time)
            tensor).
        ids_str: Per-item start offsets, indexable by batch position.
        segment_size: Window length.

    Returns:
        Tensor shaped like ``x[:, :, :segment_size]``, where item ``i`` holds
        ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``.
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    batch = x.size(0)
    for row in range(batch):
        begin = ids_str[row]
        out[row] = x[row, :, begin:begin + segment_size]
    return out
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a randomly placed ``segment_size`` window from each batch item.

    Args:
        x: Tensor of shape (batch, channels, time).
        x_lengths: Valid length of each item along the time axis. Either a
            tensor with one entry per item or a plain int shared by all
            items. Defaults to the full time dimension of ``x``.
        segment_size: Window length.

    Returns:
        ``(segments, ids_str)``: ``segments`` holds, for each item ``i``, the
        slice ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``.
        ``ids_str`` is the long tensor of start offsets that were drawn.
    """
    b, _, t = x.size()
    if x_lengths is None:
        x_lengths = t
    # torch.clamp only accepts tensors. Previously the default (or any plain
    # int) reached it as a Python int and raised TypeError. Convert here, on
    # x's device, so the product with the random draw stays on one device.
    x_lengths = torch.as_tensor(x_lengths, device=x.device)
    # The last valid start is x_lengths - segment_size. rand() is in [0, 1),
    # so flooring rand * (last + 1) yields an offset in [0, last]. Clamping
    # at 0 makes items shorter than the window start at offset 0.
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh(lower channels) * sigmoid(upper channels) of a + b.

    ``input_a + input_b`` is split along dim 1 at ``n_channels[0]``.
    Channels ``[0, n)`` go through tanh, channels ``[n, end)`` go through
    sigmoid, and the two parts are multiplied elementwise. Un-annotated
    parameters are typed as Tensor by TorchScript, so ``n_channels`` is
    presumably a one-element integer tensor (confirm at call sites).
    """
    split = n_channels[0]
    pre_activation = input_a + input_b
    return torch.tanh(pre_activation[:, :split, :]) * torch.sigmoid(
        pre_activation[:, split:, :]
    )
|