|
@@ -3,9 +3,7 @@ import typing as tp
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
from typing import List, Optional, Union
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
|
|
-import numpy as np
|
|
|
|
|
import torch
|
|
import torch
|
|
|
-from audiotools import AudioSignal
|
|
|
|
|
from audiotools.ml import BaseModel
|
|
from audiotools.ml import BaseModel
|
|
|
from dac.model.base import CodecMixin
|
|
from dac.model.base import CodecMixin
|
|
|
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
|
|
from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
|
|
@@ -64,7 +62,7 @@ class ModelArgs:
|
|
|
|
|
|
|
|
class KVCache(nn.Module):
|
|
class KVCache(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
|
|
|
|
|
|
|
+ self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
|
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
|
@@ -143,14 +141,14 @@ class Transformer(nn.Module):
|
|
|
self.use_kv_cache = True
|
|
self.use_kv_cache = True
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
|
|
- mask: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
|
|
+ mask: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
if self.config.pos_embed_type == "rope":
|
|
if self.config.pos_embed_type == "rope":
|
|
|
assert (
|
|
assert (
|
|
|
- self.freqs_cis is not None
|
|
|
|
|
|
|
+ self.freqs_cis is not None
|
|
|
), "RoPE frequencies must be initialized for RoPE positional embedding"
|
|
), "RoPE frequencies must be initialized for RoPE positional embedding"
|
|
|
# print("MAX", input_pos.max())
|
|
# print("MAX", input_pos.max())
|
|
|
freqs_cis = self.freqs_cis[input_pos]
|
|
freqs_cis = self.freqs_cis[input_pos]
|
|
@@ -182,11 +180,11 @@ class TransformerBlock(nn.Module):
|
|
|
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
|
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- input_pos: Tensor,
|
|
|
|
|
- freqs_cis: Tensor,
|
|
|
|
|
- mask: Tensor,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ input_pos: Tensor,
|
|
|
|
|
+ freqs_cis: Tensor,
|
|
|
|
|
+ mask: Tensor,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
h = x + self.attention_layer_scale(
|
|
h = x + self.attention_layer_scale(
|
|
|
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
|
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
|
@@ -241,11 +239,11 @@ class Attention(nn.Module):
|
|
|
return rel_logits
|
|
return rel_logits
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- freqs_cis: Tensor,
|
|
|
|
|
- mask: Tensor,
|
|
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ freqs_cis: Tensor,
|
|
|
|
|
+ mask: Tensor,
|
|
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
bsz, seqlen, _ = x.shape
|
|
bsz, seqlen, _ = x.shape
|
|
|
|
|
|
|
@@ -259,15 +257,46 @@ class Attention(nn.Module):
|
|
|
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
|
|
|
|
|
- if self.pos_embed_type == "rope":
|
|
|
|
|
- q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
- k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
|
|
+ # if self.pos_embed_type == "rope":
|
|
|
|
|
+ # q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
+ # k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
|
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
|
|
|
|
|
if self.kv_cache is not None:
|
|
if self.kv_cache is not None:
|
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
|
|
|
|
+ if input_pos is not None:
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ # 🔥 KV cache window 裁剪(核心优化)
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ max_context = 4096 # ⭐ 推荐 4K 或 8K
|
|
|
|
|
+
|
|
|
|
|
+ # 当前有效长度
|
|
|
|
|
+ seq_len = int(input_pos.max().item()) + 1
|
|
|
|
|
+
|
|
|
|
|
+ # window 起点
|
|
|
|
|
+ start = max(0, seq_len - max_context)
|
|
|
|
|
+
|
|
|
|
|
+ # 裁剪 KV
|
|
|
|
|
+ k = k[:, :, start:seq_len, :]
|
|
|
|
|
+ v = v[:, :, start:seq_len, :]
|
|
|
|
|
+
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ # 🔥 同步裁剪 mask(如果有)
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ if mask is not None:
|
|
|
|
|
+ mask = mask[:, :, :, start:seq_len]
|
|
|
|
|
+
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ # 🔥 同步裁剪 RoPE(关键,不然会炸)
|
|
|
|
|
+ # =========================
|
|
|
|
|
+ freqs_cis = freqs_cis[start:seq_len]
|
|
|
|
|
+
|
|
|
|
|
+ if self.pos_embed_type == "rope":
|
|
|
|
|
+ q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
+ k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
+
|
|
|
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
|
|
|
|
@@ -335,10 +364,10 @@ class RMSNorm(nn.Module):
|
|
|
|
|
|
|
|
class LayerScale(nn.Module):
|
|
class LayerScale(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- dim: int,
|
|
|
|
|
- init_values: Union[float, Tensor] = 1e-2,
|
|
|
|
|
- inplace: bool = False,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ dim: int,
|
|
|
|
|
+ init_values: Union[float, Tensor] = 1e-2,
|
|
|
|
|
+ inplace: bool = False,
|
|
|
) -> None:
|
|
) -> None:
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
self.inplace = inplace
|
|
self.inplace = inplace
|
|
@@ -354,12 +383,12 @@ class WindowLimitedTransformer(Transformer):
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- config: ModelArgs,
|
|
|
|
|
- input_dim: int = 512,
|
|
|
|
|
- window_size: Optional[int] = None,
|
|
|
|
|
- causal: bool = True,
|
|
|
|
|
- look_ahead_conv: nn.Module = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ config: ModelArgs,
|
|
|
|
|
+ input_dim: int = 512,
|
|
|
|
|
+ window_size: Optional[int] = None,
|
|
|
|
|
+ causal: bool = True,
|
|
|
|
|
+ look_ahead_conv: nn.Module = None,
|
|
|
):
|
|
):
|
|
|
super().__init__(config)
|
|
super().__init__(config)
|
|
|
self.window_size = window_size
|
|
self.window_size = window_size
|
|
@@ -380,9 +409,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def make_window_limited_mask(
|
|
def make_window_limited_mask(
|
|
|
- self,
|
|
|
|
|
- max_length: int,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ max_length: int,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
"""
|
|
"""
|
|
|
Make mask to form window limited attention.
|
|
Make mask to form window limited attention.
|
|
@@ -400,9 +429,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
return mask
|
|
return mask
|
|
|
|
|
|
|
|
def make_mask(
|
|
def make_mask(
|
|
|
- self,
|
|
|
|
|
- max_length: int,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ max_length: int,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
"""
|
|
"""
|
|
|
Make ordinary mask if window size is not specified.
|
|
Make ordinary mask if window size is not specified.
|
|
@@ -418,9 +447,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
return mask
|
|
return mask
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
if self.channels_first:
|
|
if self.channels_first:
|
|
|
x = x.transpose(1, 2)
|
|
x = x.transpose(1, 2)
|
|
@@ -442,10 +471,10 @@ class WindowLimitedTransformer(Transformer):
|
|
|
|
|
|
|
|
|
|
|
|
|
def precompute_freqs_cis(
|
|
def precompute_freqs_cis(
|
|
|
- seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
|
|
|
|
|
|
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
freqs = 1.0 / (
|
|
freqs = 1.0 / (
|
|
|
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
|
|
|
|
|
|
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
|
|
)
|
|
)
|
|
|
t = torch.arange(seq_len, device=freqs.device)
|
|
t = torch.arange(seq_len, device=freqs.device)
|
|
|
freqs = torch.outer(t, freqs)
|
|
freqs = torch.outer(t, freqs)
|
|
@@ -485,7 +514,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_extra_padding_for_conv1d(
|
|
def get_extra_padding_for_conv1d(
|
|
|
- x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
|
|
|
|
|
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
|
) -> int:
|
|
) -> int:
|
|
|
"""See `pad_for_conv1d`."""
|
|
"""See `pad_for_conv1d`."""
|
|
|
length = x.shape[-1]
|
|
length = x.shape[-1]
|
|
@@ -495,10 +524,10 @@ def get_extra_padding_for_conv1d(
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad1d(
|
|
def pad1d(
|
|
|
- x: torch.Tensor,
|
|
|
|
|
- paddings: tp.Tuple[int, int],
|
|
|
|
|
- mode: str = "zeros",
|
|
|
|
|
- value: float = 0.0,
|
|
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ paddings: tp.Tuple[int, int],
|
|
|
|
|
+ mode: str = "zeros",
|
|
|
|
|
+ value: float = 0.0,
|
|
|
):
|
|
):
|
|
|
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
|
If this is the case, we insert extra 0 padding to the right
|
|
If this is the case, we insert extra 0 padding to the right
|
|
@@ -522,14 +551,14 @@ def pad1d(
|
|
|
|
|
|
|
|
class CausalConvNet(nn.Module):
|
|
class CausalConvNet(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- in_channels,
|
|
|
|
|
- out_channels,
|
|
|
|
|
- kernel_size,
|
|
|
|
|
- dilation=1,
|
|
|
|
|
- stride=1,
|
|
|
|
|
- groups=1,
|
|
|
|
|
- padding=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ in_channels,
|
|
|
|
|
+ out_channels,
|
|
|
|
|
+ kernel_size,
|
|
|
|
|
+ dilation=1,
|
|
|
|
|
+ stride=1,
|
|
|
|
|
+ groups=1,
|
|
|
|
|
+ padding=None,
|
|
|
):
|
|
):
|
|
|
super(CausalConvNet, self).__init__()
|
|
super(CausalConvNet, self).__init__()
|
|
|
self.conv = nn.Conv1d(
|
|
self.conv = nn.Conv1d(
|
|
@@ -564,7 +593,7 @@ class CausalConvNet(nn.Module):
|
|
|
|
|
|
|
|
class CausalTransConvNet(nn.Module):
|
|
class CausalTransConvNet(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
|
|
|
|
|
|
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
|
|
):
|
|
):
|
|
|
super(CausalTransConvNet, self).__init__()
|
|
super(CausalTransConvNet, self).__init__()
|
|
|
self.conv = nn.ConvTranspose1d(
|
|
self.conv = nn.ConvTranspose1d(
|
|
@@ -618,18 +647,18 @@ class ResidualUnit(nn.Module):
|
|
|
if self.causal:
|
|
if self.causal:
|
|
|
x = x[..., :-pad]
|
|
x = x[..., :-pad]
|
|
|
else:
|
|
else:
|
|
|
- x = x[..., pad // 2 : -pad // 2]
|
|
|
|
|
|
|
+ x = x[..., pad // 2: -pad // 2]
|
|
|
return x + y
|
|
return x + y
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderBlock(nn.Module):
|
|
class EncoderBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- dim: int = 16,
|
|
|
|
|
- stride: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_t_layer: int = 0,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ dim: int = 16,
|
|
|
|
|
+ stride: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_t_layer: int = 0,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -671,13 +700,13 @@ class EncoderBlock(nn.Module):
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
class Encoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- d_model: int = 64,
|
|
|
|
|
- strides: list = [2, 4, 8, 8],
|
|
|
|
|
- d_latent: int = 64,
|
|
|
|
|
- n_transformer_layers: list = [0, 0, 4, 4],
|
|
|
|
|
- transformer_general_config: ModelArgs = None,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ d_model: int = 64,
|
|
|
|
|
+ strides: list = [2, 4, 8, 8],
|
|
|
|
|
+ d_latent: int = 64,
|
|
|
|
|
+ n_transformer_layers: list = [0, 0, 4, 4],
|
|
|
|
|
+ transformer_general_config: ModelArgs = None,
|
|
|
|
|
+ causal: bool = False,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -713,13 +742,13 @@ class Encoder(nn.Module):
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
|
|
class DecoderBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- input_dim: int = 16,
|
|
|
|
|
- output_dim: int = 8,
|
|
|
|
|
- stride: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_t_layer: int = 0,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ input_dim: int = 16,
|
|
|
|
|
+ output_dim: int = 8,
|
|
|
|
|
+ stride: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_t_layer: int = 0,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
|
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
|
@@ -761,14 +790,14 @@ class DecoderBlock(nn.Module):
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
class Decoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- input_channel,
|
|
|
|
|
- channels,
|
|
|
|
|
- rates,
|
|
|
|
|
- d_out: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_transformer_layers: list = [0, 0, 0, 0],
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ input_channel,
|
|
|
|
|
+ channels,
|
|
|
|
|
+ rates,
|
|
|
|
|
+ d_out: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_transformer_layers: list = [0, 0, 0, 0],
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -777,7 +806,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
# Add upsampling + MRF blocks
|
|
# Add upsampling + MRF blocks
|
|
|
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
|
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
|
|
- input_dim = channels // 2**i
|
|
|
|
|
|
|
+ input_dim = channels // 2 ** i
|
|
|
output_dim = channels // 2 ** (i + 1)
|
|
output_dim = channels // 2 ** (i + 1)
|
|
|
layers += [
|
|
layers += [
|
|
|
DecoderBlock(
|
|
DecoderBlock(
|
|
@@ -805,19 +834,19 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
class DAC(BaseModel, CodecMixin):
|
|
class DAC(BaseModel, CodecMixin):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- encoder_dim: int = 64,
|
|
|
|
|
- encoder_rates: List[int] = [2, 4, 8, 8],
|
|
|
|
|
- latent_dim: int = None,
|
|
|
|
|
- decoder_dim: int = 1536,
|
|
|
|
|
- decoder_rates: List[int] = [8, 8, 4, 2],
|
|
|
|
|
- quantizer: torch.nn.Module = None,
|
|
|
|
|
- sample_rate: int = 44100,
|
|
|
|
|
- causal: bool = True,
|
|
|
|
|
- encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
- decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
- overwrite_decoder: torch.nn.Module = None,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ encoder_dim: int = 64,
|
|
|
|
|
+ encoder_rates: List[int] = [2, 4, 8, 8],
|
|
|
|
|
+ latent_dim: int = None,
|
|
|
|
|
+ decoder_dim: int = 1536,
|
|
|
|
|
+ decoder_rates: List[int] = [8, 8, 4, 2],
|
|
|
|
|
+ quantizer: torch.nn.Module = None,
|
|
|
|
|
+ sample_rate: int = 44100,
|
|
|
|
|
+ causal: bool = True,
|
|
|
|
|
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
+ overwrite_decoder: torch.nn.Module = None,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
@@ -874,11 +903,11 @@ class DAC(BaseModel, CodecMixin):
|
|
|
return audio_data
|
|
return audio_data
|
|
|
|
|
|
|
|
def encode(
|
|
def encode(
|
|
|
- self,
|
|
|
|
|
- audio_data: torch.Tensor,
|
|
|
|
|
- audio_lengths: torch.Tensor = None,
|
|
|
|
|
- n_quantizers: int = None,
|
|
|
|
|
- **kwargs,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ audio_data: torch.Tensor,
|
|
|
|
|
+ audio_lengths: torch.Tensor = None,
|
|
|
|
|
+ n_quantizers: int = None,
|
|
|
|
|
+ **kwargs,
|
|
|
):
|
|
):
|
|
|
"""Encode given audio data and return quantized latent codes
|
|
"""Encode given audio data and return quantized latent codes
|
|
|
|
|
|
|
@@ -948,13 +977,13 @@ class DAC(BaseModel, CodecMixin):
|
|
|
return self.decoder(z)
|
|
return self.decoder(z)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- audio_data: torch.Tensor,
|
|
|
|
|
- template: torch.Tensor = None,
|
|
|
|
|
- mask: torch.Tensor = None,
|
|
|
|
|
- sample_rate: int = None,
|
|
|
|
|
- n_quantizers: int = None,
|
|
|
|
|
- **kwargs,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ audio_data: torch.Tensor,
|
|
|
|
|
+ template: torch.Tensor = None,
|
|
|
|
|
+ mask: torch.Tensor = None,
|
|
|
|
|
+ sample_rate: int = None,
|
|
|
|
|
+ n_quantizers: int = None,
|
|
|
|
|
+ **kwargs,
|
|
|
):
|
|
):
|
|
|
"""Model forward pass
|
|
"""Model forward pass
|
|
|
|
|
|