Lengyue 2 年 前
コミット
a12a41d831

+ 3 - 0
fish_speech/models/vq_diffusion/bigvgan/__init__.py

@@ -0,0 +1,3 @@
+from .bigvgan import BigVGAN
+
+__all__ = ["BigVGAN"]

+ 126 - 0
fish_speech/models/vq_diffusion/bigvgan/activations.py

@@ -0,0 +1,126 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+#   LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, pow, sin
+from torch.nn import Parameter
+
+
class Snake(nn.Module):
    """Sine-based periodic activation: Snake(x) = x + (1/alpha) * sin^2(alpha * x).

    Operates elementwise on tensors shaped (B, C, T), with one trainable
    alpha per channel.

    Reference: Liu Ziyin, Tilman Hartwig, Masahito Ueda,
    https://arxiv.org/abs/2006.08195

    Example:
        >>> act = Snake(256)
        >>> y = act(torch.randn(4, 256, 100))
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Args:
            in_features: number of channels (size of the per-channel alpha).
            alpha: initial frequency scale; larger means higher frequency.
            alpha_trainable: whether alpha is optimized with the model.
            alpha_logscale: if True, alpha is stored in log space and
                exponentiated in the forward pass.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale alphas start at zero (exp(0) == 1); linear-scale at one.
        init = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(init(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Small epsilon guarding the 1/alpha division.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation elementwise; input and output are (B, C, T)."""
        # Broadcast the per-channel alpha to [1, C, 1] to line up with x.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(
            torch.sin(x * alpha), 2
        )
+
+
class SnakeBeta(nn.Module):
    """Snake variant with separate per-channel frequency and magnitude:

        SnakeBeta(x) = x + (1/beta) * sin^2(alpha * x)

    where alpha controls frequency and beta controls magnitude. Operates
    elementwise on tensors shaped (B, C, T).

    Reference: modified from Liu Ziyin, Tilman Hartwig, Masahito Ueda,
    https://arxiv.org/abs/2006.08195

    Example:
        >>> act = SnakeBeta(256)
        >>> y = act(torch.randn(4, 256, 100))
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Args:
            in_features: number of channels.
            alpha: initial value for both alpha (frequency) and beta (magnitude).
            alpha_trainable: whether alpha and beta are optimized with the model.
            alpha_logscale: if True, both parameters are stored in log space
                and exponentiated in the forward pass.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start at zero (exp(0) == 1); linear-scale at one.
        init = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(init(in_features) * alpha)
        self.beta = Parameter(init(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        # Small epsilon guarding the 1/beta division.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply the activation elementwise; input and output are (B, C, T)."""
        # Broadcast the per-channel parameters to [1, C, 1] to line up with x.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
            torch.sin(x * alpha), 2
        )

+ 6 - 0
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/__init__.py

@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+from .act import *
+from .filter import *
+from .resample import *

+ 31 - 0
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/act.py

@@ -0,0 +1,31 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+
+from .resample import DownSample1d, UpSample1d
+
+
class Activation1d(nn.Module):
    """Anti-aliased activation wrapper: upsample -> activation -> downsample.

    Applying the nonlinearity at a higher sample rate and then low-pass
    filtering back down suppresses the aliasing it would otherwise introduce.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """x: [B, C, T] -> [B, C, T]."""
        return self.downsample(self.act(self.upsample(x)))

+ 100 - 0
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/filter.py

@@ -0,0 +1,100 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
# torch.sinc exists in recent torch releases; fall back to a manual
# implementation on older versions.
if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    #   LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """Normalized sinc: sin(pi * x) / (pi * x), with sinc(0) = 1.

        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.where(x == 0, one, torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+#   LICENSE is in incl_licenses directory.
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
#   LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff: normalized cutoff frequency in [0, 0.5] (0.5 == Nyquist).
        half_width: transition-band half width, used to pick the Kaiser beta.
        kernel_size: number of taps; even or odd are both supported.

    Returns:
        Filter taps of shape [1, 1, kernel_size]; all zeros when cutoff == 0.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser window design: estimate stop-band attenuation A from the
    # transition width, then derive the window's beta parameter.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()

    # BUGFIX: reshape outside the else-branch. Previously the view lived inside
    # the else, so the cutoff == 0 path raised NameError instead of returning
    # the all-zero [1, 1, kernel_size] filter.
    return filter_.view(1, 1, kernel_size)
+
+
class LowPassFilter1d(nn.Module):
    """Depthwise 1D low-pass filter using a Kaiser-windowed sinc kernel.

    kernel_size should be an even number for the StyleGAN3 setup, but this
    implementation also supports odd sizes.
    """

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding keeps the output aligned for even kernel sizes.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        # Fixed (non-trainable) taps; a buffer follows the module's device/dtype.
        self.register_buffer(
            "filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        )

    def forward(self, x):
        """x: [B, C, T] -> filtered output, strided by self.stride."""
        channels = x.shape[1]
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # One shared kernel applied per channel (depthwise convolution).
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)

+ 58 - 0
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/resample.py

@@ -0,0 +1,58 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .filter import LowPassFilter1d, kaiser_sinc_filter1d
+
+
class UpSample1d(nn.Module):
    """Anti-aliased 1D upsampler: zero-stuffing via transposed convolution,
    interpolated with a Kaiser-windowed sinc kernel."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default tap count scales with the ratio (always even).
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        # Trim amounts that keep the output exactly ratio * T samples long.
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        self.register_buffer(
            "filter",
            kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            ),
        )

    def forward(self, x):
        """x: [B, C, T] -> [B, C, ratio * T]."""
        channels = x.shape[1]
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        kernel = self.filter.expand(channels, -1, -1)
        # Scale by ratio to preserve amplitude after zero-stuffing.
        x = self.ratio * F.conv_transpose1d(
            x, kernel, stride=self.stride, groups=channels
        )
        return x[..., self.pad_left : -self.pad_right]
+
+
class DownSample1d(nn.Module):
    """Anti-aliased 1D downsampler: a low-pass filter applied with
    stride == ratio."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default tap count scales with the ratio (always even).
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        """x: [B, C, T] -> [B, C, ~T // ratio]."""
        return self.lowpass(x)

+ 389 - 0
fish_speech/models/vq_diffusion/bigvgan/bigvgan.py

@@ -0,0 +1,389 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+#   Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+#   LICENSE is in incl_licenses directory.
+
+
+import json
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, Conv2d, ConvTranspose1d
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from fish_speech.models.vq_diffusion.bigvgan.activations import Snake, SnakeBeta
+from fish_speech.models.vq_diffusion.bigvgan.alias_free_torch import Activation1d
+from fish_speech.models.vq_diffusion.bigvgan.utils import get_padding, init_weights
+from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+LRELU_SLOPE = 0.1
+
+
class AttrDict(dict):
    """Dict whose items are also readable/writable as attributes (h.key)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict itself so item and
        # attribute access always stay in sync.
        self.__dict__ = self
+
+
class AMPBlock1(torch.nn.Module):
    """Anti-aliased multi-periodicity (AMP) residual block, variant 1.

    Each residual step is activation -> dilated conv -> activation -> conv
    (dilation 1), added back to the input. Activations are Snake/SnakeBeta
    wrapped in Activation1d for anti-aliasing.

    Args:
        h: config namespace; only `h.snake_logscale` is read here.
        channels: channel count (input == output).
        kernel_size: convolution kernel size.
        dilation: dilation per residual step. Generalized: any length works
            (previously exactly three entries were required by index).
        activation: "snake" or "snakebeta".
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super(AMPBlock1, self).__init__()
        self.h = h

        # One dilated convolution per entry of `dilation` (indices preserved,
        # so checkpoint state-dict keys are unchanged).
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # A matching non-dilated convolution for each residual step.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

        # Total number of conv layers; one activation precedes each conv.
        self.num_layers = len(self.convs1) + len(self.convs2)

        if activation == "snake":
            # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=h.snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        """Residual forward pass over all (act, dilated conv, act, conv) steps."""
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convs (for inference/export)."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
+
+
class AMPBlock2(torch.nn.Module):
    """Anti-aliased multi-periodicity (AMP) residual block, variant 2.

    Lighter than AMPBlock1: each residual step is just activation -> dilated
    conv, added back to the input.

    Args:
        h: config namespace; only `h.snake_logscale` is read here.
        channels: channel count (input == output).
        kernel_size: convolution kernel size.
        dilation: dilation per residual step. Generalized: any length works
            (previously exactly two entries were required by index).
        activation: "snake" or "snakebeta".
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
        super(AMPBlock2, self).__init__()
        self.h = h

        # One dilated convolution per entry of `dilation` (indices preserved,
        # so checkpoint state-dict keys are unchanged).
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

        self.num_layers = len(self.convs)  # total number of conv layers

        if activation == "snake":
            # periodic nonlinearity with snake function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=Snake(channels, alpha_logscale=h.snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        elif activation == "snakebeta":
            # periodic nonlinearity with snakebeta function and anti-aliasing
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale)
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x):
        """Residual forward pass over all (act, conv) steps."""
        for c, a in zip(self.convs, self.activations):
            xt = a(x)
            xt = c(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from all convs (for inference/export)."""
        for conv in self.convs:
            remove_weight_norm(conv)
+
+
class BigVGANModule(torch.nn.Module):
    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
    def __init__(self, h):
        """Build the BigVGAN generator from the config namespace `h`.

        Fields read from `h`: num_mels, upsample_initial_channel,
        upsample_rates, upsample_kernel_sizes, resblock ("1" -> AMPBlock1,
        otherwise AMPBlock2), resblock_kernel_sizes, resblock_dilation_sizes,
        activation ("snake" or "snakebeta"), snake_logscale.
        """
        super(BigVGANModule, self).__init__()
        self.h = h

        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # pre conv
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )

        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2

        # transposed conv-based upsamplers. does not apply anti-aliasing
        # NOTE(review): each stage is a one-element ModuleList, which keeps
        # state-dict keys as "ups.<i>.0.*" — presumably for checkpoint
        # compatibility; do not flatten.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                nn.ModuleList(
                    [
                        weight_norm(
                            ConvTranspose1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                u,
                                padding=(k - u) // 2,
                            )
                        )
                    ]
                )
            )

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            # Channels are halved after every upsample stage.
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))

        # post conv
        # `ch` carries over from the loop above: channel count of the last stage.
        if (
            h.activation == "snake"
        ):  # periodic nonlinearity with snake function and anti-aliasing
            activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        elif (
            h.activation == "snakebeta"
        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
            activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Mel spectrogram [B, num_mels, T] -> waveform [B, 1, T'] in [-1, 1] (tanh)."""
        # pre conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks: average the outputs of the num_kernels parallel resblocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from every conv in place (call before inference/export)."""
        print("Removing weight norm...")
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
+
+
class BigVGAN(nn.Module):
    """Inference wrapper around a pretrained BigVGAN generator.

    Loads the generator weights and matching JSON config from disk, strips
    weight norm, and exposes mel-spectrogram encode / waveform decode helpers.

    Args:
        checkpoint_path: path to the generator checkpoint (a dict containing
            a "generator" state-dict entry).
        config_file: path to the config JSON; defaults to `config.json` next
            to the checkpoint.
    """

    def __init__(
        self,
        checkpoint_path: str = "checkpoints/bigvgan-24k-100band/g_05000000",
        config_file: Optional[str] = None,
    ):
        super().__init__()

        if config_file is None:
            config_file = Path(checkpoint_path).parent / "config.json"

        # AttrDict gives attribute-style access to the config (h.n_fft, ...).
        with open(config_file) as f:
            self.h = AttrDict(json.load(f))

        self.model = BigVGANModule(self.h)

        state_dict = torch.load(checkpoint_path, map_location="cpu")["generator"]
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        # Weight norm must be removed only after the weights are loaded.
        self.model.remove_weight_norm()

        # Mel front-end configured from the same config as the generator.
        self.mel_transform = LogMelSpectrogram(
            sample_rate=self.h.sampling_rate,
            n_fft=self.h.n_fft,
            win_length=self.h.win_size,
            hop_length=self.h.hop_size,
            f_min=self.h.fmin,
            f_max=self.h.fmax,
            n_mels=self.h.num_mels,
        )

    @torch.no_grad()
    def decode(self, mel):
        """Mel spectrogram [B, num_mels, T] -> waveform [B, 1, T'] (tanh range)."""
        return self.model(mel)

    @torch.no_grad()
    def encode(self, x):
        """Waveform -> log-mel spectrogram via the configured mel transform."""
        return self.mel_transform(x)
+
+
if __name__ == "__main__":
    # Manual round-trip smoke test: waveform -> mel -> waveform.
    import librosa
    import soundfile as sf

    x = "data/StarRail/Chinese/罗刹/archive_luocha_2.wav"
    model = BigVGAN()

    # Load mono audio at the vocoder's expected 24 kHz sample rate.
    wav, sr = librosa.load(x, sr=24000, mono=True)
    wav = torch.from_numpy(wav).float()[None]  # add batch dim -> [1, T]
    mel = model.encode(wav)

    # Decode back to audio; .mT transposes to [T, C] as soundfile expects.
    wav = model.decode(mel)[0].mT
    sf.write("test.wav", wav.cpu().numpy(), 24000)

+ 80 - 0
fish_speech/models/vq_diffusion/bigvgan/utils.py

@@ -0,0 +1,80 @@
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+#   LICENSE is in incl_licenses directory.
+
+import glob
+import os
+
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+from scipy.io.wavfile import write
+
+
def plot_spectrogram(spectrogram):
    """Render a 2D spectrogram array to a matplotlib figure with a colorbar."""
    fig, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(
        spectrogram, aspect="auto", origin="lower", interpolation="none"
    )
    plt.colorbar(image, ax=axis)
    fig.canvas.draw()
    # Close so pyplot's global registry does not keep the figure alive.
    plt.close()
    return fig
+
+
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
    """Render a spectrogram with the color range clipped to [1e-6, clip_max]."""
    fig, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
        vmin=1e-6,
        vmax=clip_max,
    )
    plt.colorbar(image, ax=axis)
    fig.canvas.draw()
    # Close so pyplot's global registry does not keep the figure alive.
    plt.close()
    return fig
+
+
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in place from N(mean, std).

    Intended for `module.apply(init_weights)`; non-conv modules are untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
+
+
def apply_weight_norm(m):
    """Attach weight normalization to conv layers (for `module.apply`)."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
+
+
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated convolution length-preserving."""
    return int((kernel_size - 1) * dilation / 2)
+
+
def load_checkpoint(filepath, device):
    """Load a torch checkpoint dict from `filepath` onto `device`.

    Raises:
        FileNotFoundError: if `filepath` does not exist. (Previously an
        `assert`, which is stripped under `python -O`; FileNotFoundError is
        also what `torch.load` would raise on a missing file.)
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint not found: {filepath}")
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
+
+
def save_checkpoint(filepath, obj):
    """Serialize `obj` to `filepath` via torch.save, logging progress."""
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")
+
+
def scan_checkpoint(cp_dir, prefix):
    """Return the newest checkpoint path in `cp_dir` matching `prefix` plus
    eight characters, or None when there is none.

    Lexical sort is sufficient because step numbers are zero-padded.
    """
    candidates = sorted(glob.glob(os.path.join(cp_dir, prefix + "????????")))
    if not candidates:
        return None
    return candidates[-1]