"""WaveNet-style residual network: a stack of gated residual blocks usable as a
plain feature encoder or, with `is_diffusion=True`, as a diffusion denoiser
driven by a sinusoidal step embedding and an optional conditioning sequence."""

import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

class DiffusionEmbedding(nn.Module):
    """Sinusoidal diffusion-step embedding (assumes an even dimension)."""

    def __init__(self, d_denoiser):
        super().__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000, as in
        # Transformer positional encodings.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class LinearNorm(nn.Module):
    """Linear projection with Xavier-uniform initialization."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        return self.linear(x)

class ConvNorm(nn.Module):
    """1D convolution with Kaiming-normal initialization."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",  # kept for API compatibility; currently unused
    ):
        super().__init__()

        if padding is None:
            # Default to "same" padding, which requires an odd kernel size.
            assert kernel_size % 2 == 1
            padding = dilation * (kernel_size - 1) // 2

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        return self.conv(signal)

class ResidualBlock(nn.Module):
    """WaveNet residual block: dilated conv, gated activation, and optional
    diffusion-step / conditioning inputs."""

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super().__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        # Projections for the diffusion-step embedding and the conditioning
        # sequence are only built when conditioning is enabled.
        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        y = x

        if diffusion_step is not None:
            # Broadcast the (B, C) step embedding over the time axis.
            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            y = y + diffusion_step

        y = self.conv_layer(y)

        if condition is not None:
            condition = self.condition_projection(condition)
            y = y + condition

        # Gated activation unit: sigmoid(gate) * tanh(filter).
        gate, filter_ = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter_)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)

        # Scale the residual path so variance stays roughly constant.
        return (x + residual) / math.sqrt(2.0), skip

class WaveNet(nn.Module):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection (only when the input width differs from the
        # residual width)
        self.input_projection = None
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )

        if input_channels is None:
            input_channels = residual_channels

        self.input_channels = input_channels

        # Residual layers with dilations 2**0 .. 2**(dilation_cycle - 1),
        # repeated cyclically
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection (only when the output width differs from the
        # residual width)
        self.output_projection = None
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Note: this runs after construction, so it overrides the Xavier /
        # Kaiming initialization done inside LinearNorm and ConvNorm.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        if self.input_projection is not None:
            x = self.input_projection(x)
            x = F.silu(x)

        # Passing t requires the model to be built with is_diffusion=True.
        if t is not None:
            t = self.diffusion_embedding(t)
            t = self.mlp(t)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, condition, t)
            skip.append(skip_connection)

        # Combine skip connections, normalized by sqrt(num_layers).
        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x
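

# --- Minimal usage sketch (illustrative, not part of the original module).
# The shapes and hyperparameters below are assumptions chosen for a quick
# smoke test; real configurations will differ.
if __name__ == "__main__":
    batch, mel_channels, frames = 2, 80, 64

    # Plain (non-diffusion) usage: project 80-channel features through the
    # residual stack and back out to 80 channels.
    net = WaveNet(
        input_channels=mel_channels,
        output_channels=mel_channels,
        residual_channels=256,
        residual_layers=8,
    )
    x = torch.randn(batch, mel_channels, frames)
    print(net(x).shape)  # torch.Size([2, 80, 64])

    # Diffusion usage: integer timesteps plus a conditioning sequence with
    # condition_channels channels.
    denoiser = WaveNet(
        input_channels=mel_channels,
        output_channels=mel_channels,
        residual_channels=256,
        residual_layers=8,
        is_diffusion=True,
        condition_channels=128,
    )
    t = torch.randint(0, 1000, (batch,))
    cond = torch.randn(batch, 128, frames)
    print(denoiser(x, t=t, condition=cond).shape)  # torch.Size([2, 80, 64])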