|
@@ -143,14 +143,14 @@ class Transformer(nn.Module):
|
|
|
self.use_kv_cache = True
|
|
self.use_kv_cache = True
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
|
|
- mask: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
|
|
+ mask: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
if self.config.pos_embed_type == "rope":
|
|
if self.config.pos_embed_type == "rope":
|
|
|
assert (
|
|
assert (
|
|
|
- self.freqs_cis is not None
|
|
|
|
|
|
|
+ self.freqs_cis is not None
|
|
|
), "RoPE frequencies must be initialized for RoPE positional embedding"
|
|
), "RoPE frequencies must be initialized for RoPE positional embedding"
|
|
|
# print("MAX", input_pos.max())
|
|
# print("MAX", input_pos.max())
|
|
|
freqs_cis = self.freqs_cis[input_pos]
|
|
freqs_cis = self.freqs_cis[input_pos]
|
|
@@ -182,11 +182,11 @@ class TransformerBlock(nn.Module):
|
|
|
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
|
self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- input_pos: Tensor,
|
|
|
|
|
- freqs_cis: Tensor,
|
|
|
|
|
- mask: Tensor,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ input_pos: Tensor,
|
|
|
|
|
+ freqs_cis: Tensor,
|
|
|
|
|
+ mask: Tensor,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
h = x + self.attention_layer_scale(
|
|
h = x + self.attention_layer_scale(
|
|
|
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
|
self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
|
@@ -241,16 +241,14 @@ class Attention(nn.Module):
|
|
|
return rel_logits
|
|
return rel_logits
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- freqs_cis: Tensor,
|
|
|
|
|
- mask: Tensor,
|
|
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ freqs_cis: Tensor,
|
|
|
|
|
+ mask: Tensor,
|
|
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
bsz, seqlen, _ = x.shape
|
|
bsz, seqlen, _ = x.shape
|
|
|
|
|
|
|
|
- print(f"Attention forward self.n_local_heads {self.n_local_heads}, self.head_dim {self.head_dim}")
|
|
|
|
|
-
|
|
|
|
|
kv_size = self.n_local_heads * self.head_dim
|
|
kv_size = self.n_local_heads * self.head_dim
|
|
|
q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
|
|
q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
|
|
|
context_seqlen = seqlen
|
|
context_seqlen = seqlen
|
|
@@ -259,48 +257,15 @@ class Attention(nn.Module):
|
|
|
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
|
|
|
|
|
- # if self.pos_embed_type == "rope":
|
|
|
|
|
- # q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
- # k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
|
|
+ if self.pos_embed_type == "rope":
|
|
|
|
|
+ q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
+ k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
|
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
|
|
|
|
|
if self.kv_cache is not None:
|
|
if self.kv_cache is not None:
|
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
|
|
- if input_pos is not None:
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 🔥 KV cache window 裁剪(核心优化)
|
|
|
|
|
- # =========================
|
|
|
|
|
- max_context = 4096 # ⭐ 推荐 4K 或 8K
|
|
|
|
|
-
|
|
|
|
|
- # 当前有效长度
|
|
|
|
|
- seq_len = int(input_pos.max().item()) + 1
|
|
|
|
|
-
|
|
|
|
|
- # window 起点
|
|
|
|
|
- start = max(0, seq_len - max_context)
|
|
|
|
|
-
|
|
|
|
|
- # 裁剪 KV
|
|
|
|
|
- k = k[:, :, start:seq_len, :]
|
|
|
|
|
- v = v[:, :, start:seq_len, :]
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 🔥 同步裁剪 mask(如果有)
|
|
|
|
|
- # =========================
|
|
|
|
|
- if mask is not None:
|
|
|
|
|
- mask = mask[:, :, :, start:seq_len]
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 🔥 同步裁剪 RoPE(关键,不然会炸)
|
|
|
|
|
- # =========================
|
|
|
|
|
- print(f"input_pos.dtype {input_pos.dtype}")
|
|
|
|
|
- assert input_pos.dtype == torch.long
|
|
|
|
|
- freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
|
|
|
|
|
-
|
|
|
|
|
- if self.pos_embed_type == "rope":
|
|
|
|
|
- q = apply_rotary_emb(q, freqs_cis)
|
|
|
|
|
- k = apply_rotary_emb(k, freqs_cis)
|
|
|
|
|
-
|
|
|
|
|
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
|
|
|
|
@@ -368,10 +333,10 @@ class RMSNorm(nn.Module):
|
|
|
|
|
|
|
|
class LayerScale(nn.Module):
|
|
class LayerScale(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- dim: int,
|
|
|
|
|
- init_values: Union[float, Tensor] = 1e-2,
|
|
|
|
|
- inplace: bool = False,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ dim: int,
|
|
|
|
|
+ init_values: Union[float, Tensor] = 1e-2,
|
|
|
|
|
+ inplace: bool = False,
|
|
|
) -> None:
|
|
) -> None:
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
self.inplace = inplace
|
|
self.inplace = inplace
|
|
@@ -387,12 +352,12 @@ class WindowLimitedTransformer(Transformer):
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- config: ModelArgs,
|
|
|
|
|
- input_dim: int = 512,
|
|
|
|
|
- window_size: Optional[int] = None,
|
|
|
|
|
- causal: bool = True,
|
|
|
|
|
- look_ahead_conv: nn.Module = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ config: ModelArgs,
|
|
|
|
|
+ input_dim: int = 512,
|
|
|
|
|
+ window_size: Optional[int] = None,
|
|
|
|
|
+ causal: bool = True,
|
|
|
|
|
+ look_ahead_conv: nn.Module = None,
|
|
|
):
|
|
):
|
|
|
super().__init__(config)
|
|
super().__init__(config)
|
|
|
self.window_size = window_size
|
|
self.window_size = window_size
|
|
@@ -413,9 +378,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def make_window_limited_mask(
|
|
def make_window_limited_mask(
|
|
|
- self,
|
|
|
|
|
- max_length: int,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ max_length: int,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
"""
|
|
"""
|
|
|
Make mask to form window limited attention.
|
|
Make mask to form window limited attention.
|
|
@@ -433,9 +398,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
return mask
|
|
return mask
|
|
|
|
|
|
|
|
def make_mask(
|
|
def make_mask(
|
|
|
- self,
|
|
|
|
|
- max_length: int,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ max_length: int,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
"""
|
|
"""
|
|
|
Make ordinary mask if window size is not specified.
|
|
Make ordinary mask if window size is not specified.
|
|
@@ -451,9 +416,9 @@ class WindowLimitedTransformer(Transformer):
|
|
|
return mask
|
|
return mask
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- x: Tensor,
|
|
|
|
|
- x_lens: Optional[Tensor] = None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: Tensor,
|
|
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
if self.channels_first:
|
|
if self.channels_first:
|
|
|
x = x.transpose(1, 2)
|
|
x = x.transpose(1, 2)
|
|
@@ -475,10 +440,10 @@ class WindowLimitedTransformer(Transformer):
|
|
|
|
|
|
|
|
|
|
|
|
|
def precompute_freqs_cis(
|
|
def precompute_freqs_cis(
|
|
|
- seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
|
|
|
|
|
|
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
|
|
) -> Tensor:
|
|
) -> Tensor:
|
|
|
freqs = 1.0 / (
|
|
freqs = 1.0 / (
|
|
|
- base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
|
|
|
|
|
|
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
|
|
)
|
|
)
|
|
|
t = torch.arange(seq_len, device=freqs.device)
|
|
t = torch.arange(seq_len, device=freqs.device)
|
|
|
freqs = torch.outer(t, freqs)
|
|
freqs = torch.outer(t, freqs)
|
|
@@ -518,7 +483,7 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_extra_padding_for_conv1d(
|
|
def get_extra_padding_for_conv1d(
|
|
|
- x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
|
|
|
|
|
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
|
) -> int:
|
|
) -> int:
|
|
|
"""See `pad_for_conv1d`."""
|
|
"""See `pad_for_conv1d`."""
|
|
|
length = x.shape[-1]
|
|
length = x.shape[-1]
|
|
@@ -528,10 +493,10 @@ def get_extra_padding_for_conv1d(
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad1d(
|
|
def pad1d(
|
|
|
- x: torch.Tensor,
|
|
|
|
|
- paddings: tp.Tuple[int, int],
|
|
|
|
|
- mode: str = "zeros",
|
|
|
|
|
- value: float = 0.0,
|
|
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ paddings: tp.Tuple[int, int],
|
|
|
|
|
+ mode: str = "zeros",
|
|
|
|
|
+ value: float = 0.0,
|
|
|
):
|
|
):
|
|
|
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
|
If this is the case, we insert extra 0 padding to the right
|
|
If this is the case, we insert extra 0 padding to the right
|
|
@@ -555,14 +520,14 @@ def pad1d(
|
|
|
|
|
|
|
|
class CausalConvNet(nn.Module):
|
|
class CausalConvNet(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- in_channels,
|
|
|
|
|
- out_channels,
|
|
|
|
|
- kernel_size,
|
|
|
|
|
- dilation=1,
|
|
|
|
|
- stride=1,
|
|
|
|
|
- groups=1,
|
|
|
|
|
- padding=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ in_channels,
|
|
|
|
|
+ out_channels,
|
|
|
|
|
+ kernel_size,
|
|
|
|
|
+ dilation=1,
|
|
|
|
|
+ stride=1,
|
|
|
|
|
+ groups=1,
|
|
|
|
|
+ padding=None,
|
|
|
):
|
|
):
|
|
|
super(CausalConvNet, self).__init__()
|
|
super(CausalConvNet, self).__init__()
|
|
|
self.conv = nn.Conv1d(
|
|
self.conv = nn.Conv1d(
|
|
@@ -597,7 +562,7 @@ class CausalConvNet(nn.Module):
|
|
|
|
|
|
|
|
class CausalTransConvNet(nn.Module):
|
|
class CausalTransConvNet(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
|
|
|
|
|
|
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
|
|
):
|
|
):
|
|
|
super(CausalTransConvNet, self).__init__()
|
|
super(CausalTransConvNet, self).__init__()
|
|
|
self.conv = nn.ConvTranspose1d(
|
|
self.conv = nn.ConvTranspose1d(
|
|
@@ -651,18 +616,18 @@ class ResidualUnit(nn.Module):
|
|
|
if self.causal:
|
|
if self.causal:
|
|
|
x = x[..., :-pad]
|
|
x = x[..., :-pad]
|
|
|
else:
|
|
else:
|
|
|
- x = x[..., pad // 2: -pad // 2]
|
|
|
|
|
|
|
+ x = x[..., pad // 2 : -pad // 2]
|
|
|
return x + y
|
|
return x + y
|
|
|
|
|
|
|
|
|
|
|
|
|
class EncoderBlock(nn.Module):
|
|
class EncoderBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- dim: int = 16,
|
|
|
|
|
- stride: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_t_layer: int = 0,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ dim: int = 16,
|
|
|
|
|
+ stride: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_t_layer: int = 0,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -704,13 +669,13 @@ class EncoderBlock(nn.Module):
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
class Encoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- d_model: int = 64,
|
|
|
|
|
- strides: list = [2, 4, 8, 8],
|
|
|
|
|
- d_latent: int = 64,
|
|
|
|
|
- n_transformer_layers: list = [0, 0, 4, 4],
|
|
|
|
|
- transformer_general_config: ModelArgs = None,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ d_model: int = 64,
|
|
|
|
|
+ strides: list = [2, 4, 8, 8],
|
|
|
|
|
+ d_latent: int = 64,
|
|
|
|
|
+ n_transformer_layers: list = [0, 0, 4, 4],
|
|
|
|
|
+ transformer_general_config: ModelArgs = None,
|
|
|
|
|
+ causal: bool = False,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -746,13 +711,13 @@ class Encoder(nn.Module):
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
|
|
class DecoderBlock(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- input_dim: int = 16,
|
|
|
|
|
- output_dim: int = 8,
|
|
|
|
|
- stride: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_t_layer: int = 0,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ input_dim: int = 16,
|
|
|
|
|
+ output_dim: int = 8,
|
|
|
|
|
+ stride: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_t_layer: int = 0,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
|
conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
|
@@ -794,14 +759,14 @@ class DecoderBlock(nn.Module):
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
class Decoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- input_channel,
|
|
|
|
|
- channels,
|
|
|
|
|
- rates,
|
|
|
|
|
- d_out: int = 1,
|
|
|
|
|
- causal: bool = False,
|
|
|
|
|
- n_transformer_layers: list = [0, 0, 0, 0],
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ input_channel,
|
|
|
|
|
+ channels,
|
|
|
|
|
+ rates,
|
|
|
|
|
+ d_out: int = 1,
|
|
|
|
|
+ causal: bool = False,
|
|
|
|
|
+ n_transformer_layers: list = [0, 0, 0, 0],
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
conv_class = CausalWNConv1d if causal else WNConv1d
|
|
@@ -810,7 +775,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
# Add upsampling + MRF blocks
|
|
# Add upsampling + MRF blocks
|
|
|
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
|
for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
|
|
- input_dim = channels // 2 ** i
|
|
|
|
|
|
|
+ input_dim = channels // 2**i
|
|
|
output_dim = channels // 2 ** (i + 1)
|
|
output_dim = channels // 2 ** (i + 1)
|
|
|
layers += [
|
|
layers += [
|
|
|
DecoderBlock(
|
|
DecoderBlock(
|
|
@@ -838,19 +803,19 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
class DAC(BaseModel, CodecMixin):
|
|
class DAC(BaseModel, CodecMixin):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
- self,
|
|
|
|
|
- encoder_dim: int = 64,
|
|
|
|
|
- encoder_rates: List[int] = [2, 4, 8, 8],
|
|
|
|
|
- latent_dim: int = None,
|
|
|
|
|
- decoder_dim: int = 1536,
|
|
|
|
|
- decoder_rates: List[int] = [8, 8, 4, 2],
|
|
|
|
|
- quantizer: torch.nn.Module = None,
|
|
|
|
|
- sample_rate: int = 44100,
|
|
|
|
|
- causal: bool = True,
|
|
|
|
|
- encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
- decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
- overwrite_decoder: torch.nn.Module = None,
|
|
|
|
|
- transformer_general_config=None,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ encoder_dim: int = 64,
|
|
|
|
|
+ encoder_rates: List[int] = [2, 4, 8, 8],
|
|
|
|
|
+ latent_dim: int = None,
|
|
|
|
|
+ decoder_dim: int = 1536,
|
|
|
|
|
+ decoder_rates: List[int] = [8, 8, 4, 2],
|
|
|
|
|
+ quantizer: torch.nn.Module = None,
|
|
|
|
|
+ sample_rate: int = 44100,
|
|
|
|
|
+ causal: bool = True,
|
|
|
|
|
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
|
|
+ overwrite_decoder: torch.nn.Module = None,
|
|
|
|
|
+ transformer_general_config=None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
@@ -907,11 +872,11 @@ class DAC(BaseModel, CodecMixin):
|
|
|
return audio_data
|
|
return audio_data
|
|
|
|
|
|
|
|
def encode(
|
|
def encode(
|
|
|
- self,
|
|
|
|
|
- audio_data: torch.Tensor,
|
|
|
|
|
- audio_lengths: torch.Tensor = None,
|
|
|
|
|
- n_quantizers: int = None,
|
|
|
|
|
- **kwargs,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ audio_data: torch.Tensor,
|
|
|
|
|
+ audio_lengths: torch.Tensor = None,
|
|
|
|
|
+ n_quantizers: int = None,
|
|
|
|
|
+ **kwargs,
|
|
|
):
|
|
):
|
|
|
"""Encode given audio data and return quantized latent codes
|
|
"""Encode given audio data and return quantized latent codes
|
|
|
|
|
|
|
@@ -981,13 +946,13 @@ class DAC(BaseModel, CodecMixin):
|
|
|
return self.decoder(z)
|
|
return self.decoder(z)
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
- self,
|
|
|
|
|
- audio_data: torch.Tensor,
|
|
|
|
|
- template: torch.Tensor = None,
|
|
|
|
|
- mask: torch.Tensor = None,
|
|
|
|
|
- sample_rate: int = None,
|
|
|
|
|
- n_quantizers: int = None,
|
|
|
|
|
- **kwargs,
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ audio_data: torch.Tensor,
|
|
|
|
|
+ template: torch.Tensor = None,
|
|
|
|
|
+ mask: torch.Tensor = None,
|
|
|
|
|
+ sample_rate: int = None,
|
|
|
|
|
+ n_quantizers: int = None,
|
|
|
|
|
+ **kwargs,
|
|
|
):
|
|
):
|
|
|
"""Model forward pass
|
|
"""Model forward pass
|
|
|
|
|
|