|
|
@@ -0,0 +1,1021 @@
|
|
|
+import math
|
|
|
+import typing as tp
|
|
|
+from dataclasses import dataclass
|
|
|
+from typing import List, Optional, Union
|
|
|
+
|
|
|
+import hydra
|
|
|
+import librosa
|
|
|
+import numpy as np
|
|
|
+import soundfile as sf
|
|
|
+import torch
|
|
|
+from audiotools import AudioSignal
|
|
|
+from audiotools.ml import BaseModel
|
|
|
+from dac.model.base import CodecMixin
|
|
|
+from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
|
|
|
+from omegaconf import OmegaConf
|
|
|
+from torch import Tensor, nn
|
|
|
+from torch.nn import functional as F
|
|
|
+from torch.nn.utils.parametrizations import weight_norm
|
|
|
+from torch.nn.utils.parametrize import remove_parametrizations
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class VQResult:
|
|
|
+ z: torch.Tensor
|
|
|
+ codes: torch.Tensor
|
|
|
+ latents: torch.Tensor
|
|
|
+ codebook_loss: torch.Tensor
|
|
|
+ commitment_loss: torch.Tensor
|
|
|
+ semantic_distill_z: torch.Tensor | None = None
|
|
|
+
|
|
|
+
|
|
|
+def find_multiple(n: int, k: int) -> int:
|
|
|
+ if n % k == 0:
|
|
|
+ return n
|
|
|
+ return n + k - (n % k)
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class ModelArgs:
|
|
|
+ block_size: int = 2048
|
|
|
+ n_layer: int = 8
|
|
|
+ n_head: int = 8
|
|
|
+ dim: int = 512
|
|
|
+ intermediate_size: int = 1536
|
|
|
+ n_local_heads: int = -1
|
|
|
+ head_dim: int = 64
|
|
|
+ rope_base: float = 10000
|
|
|
+ norm_eps: float = 1e-5
|
|
|
+ dropout_rate: float = 0.1
|
|
|
+ attn_dropout_rate: float = 0.1
|
|
|
+ channels_first: bool = True # to be compatible with conv1d input/output
|
|
|
+ pos_embed_type: str = "rope" # can be "rope" or "conformer"
|
|
|
+ max_relative_position: int = 128 # for conformer-style relative position embedding
|
|
|
+
|
|
|
+ def __post_init__(self):
|
|
|
+ if self.n_local_heads == -1:
|
|
|
+ self.n_local_heads = self.n_head
|
|
|
+ if self.intermediate_size is None:
|
|
|
+ hidden_dim = 4 * self.dim
|
|
|
+ n_hidden = int(2 * hidden_dim / 3)
|
|
|
+ self.intermediate_size = find_multiple(n_hidden, 256)
|
|
|
+ assert self.pos_embed_type in [
|
|
|
+ "rope",
|
|
|
+ "conformer",
|
|
|
+ ], "pos_embed_type must be either 'rope' or 'conformer'"
|
|
|
+
|
|
|
+
|
|
|
+class KVCache(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
|
|
|
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
|
|
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
|
|
+
|
|
|
+ def update(self, input_pos, k_val, v_val):
|
|
|
+ # input_pos: [S], k_val: [B, H, S, D]
|
|
|
+ assert input_pos.shape[0] == k_val.shape[2]
|
|
|
+
|
|
|
+ k_out = self.k_cache
|
|
|
+ v_out = self.v_cache
|
|
|
+ k_out[:, :, input_pos] = k_val
|
|
|
+ v_out[:, :, input_pos] = v_val
|
|
|
+
|
|
|
+ return (
|
|
|
+ k_out[:, :, : input_pos.max() + 1, :],
|
|
|
+ v_out[:, :, : input_pos.max() + 1, :],
|
|
|
+ )
|
|
|
+
|
|
|
+ def clear_cache(self, prompt_len):
|
|
|
+ self.k_cache[:, :, prompt_len:, :].fill_(0)
|
|
|
+ self.v_cache[:, :, prompt_len:, :].fill_(0)
|
|
|
+
|
|
|
+
|
|
|
+class Transformer(nn.Module):
|
|
|
+ def __init__(self, config: ModelArgs) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.config = config
|
|
|
+
|
|
|
+ self.layers = nn.ModuleList(
|
|
|
+ TransformerBlock(config) for _ in range(config.n_layer)
|
|
|
+ )
|
|
|
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
|
|
+
|
|
|
+ # Only compute RoPE frequencies if using RoPE
|
|
|
+ if config.pos_embed_type == "rope":
|
|
|
+ freqs_cis = precompute_freqs_cis(
|
|
|
+ self.config.block_size, self.config.head_dim, self.config.rope_base
|
|
|
+ )
|
|
|
+ self.register_buffer("freqs_cis", freqs_cis)
|
|
|
+ else:
|
|
|
+ self.register_buffer("freqs_cis", None)
|
|
|
+
|
|
|
+ causal_mask = torch.tril(
|
|
|
+ torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
|
|
|
+ )
|
|
|
+ self.register_buffer("causal_mask", causal_mask)
|
|
|
+
|
|
|
+ self.max_batch_size = -1
|
|
|
+ self.max_seq_length = -1
|
|
|
+ self.use_kv_cache = False
|
|
|
+
|
|
|
+ def setup_caches(self, max_batch_size, max_seq_length):
|
|
|
+ """
|
|
|
+ This method will only be called during inference when using KV cache.
|
|
|
+ """
|
|
|
+ head_dim = self.config.head_dim # must match the per-head dim used by Attention for k/v
|
|
|
+ max_seq_length = find_multiple(max_seq_length, 8)
|
|
|
+ self.max_seq_length = max_seq_length
|
|
|
+ self.max_batch_size = max_batch_size
|
|
|
+ dtype = self.norm.weight.dtype
|
|
|
+ device = self.norm.weight.device
|
|
|
+
|
|
|
+ for b in self.layers:
|
|
|
+ b.attention.kv_cache = KVCache(
|
|
|
+ max_batch_size,
|
|
|
+ max_seq_length,
|
|
|
+ self.config.n_local_heads,
|
|
|
+ head_dim,
|
|
|
+ dtype,
|
|
|
+ ).to(device)
|
|
|
+
|
|
|
+ self.use_kv_cache = True
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ x: Tensor,
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
+ mask: Optional[Tensor] = None,
|
|
|
+ ) -> Tensor:
|
|
|
+ if self.config.pos_embed_type == "rope":
|
|
|
+ assert (
|
|
|
+ self.freqs_cis is not None
|
|
|
+ ), "RoPE frequencies must be initialized for RoPE positional embedding"
|
|
|
+ freqs_cis = self.freqs_cis[input_pos]
|
|
|
+ else:
|
|
|
+ freqs_cis = None
|
|
|
+
|
|
|
+ if mask is None: # in case of non-causal model
|
|
|
+ if not self.training and self.use_kv_cache:
|
|
|
+ mask = self.causal_mask[None, None, input_pos]
|
|
|
+ mask = mask[..., : input_pos.max() + 1]
|
|
|
+ else:
|
|
|
+ mask = self.causal_mask[None, None, input_pos]
|
|
|
+ mask = mask[..., input_pos]
|
|
|
+
|
|
|
+ for i, layer in enumerate(self.layers):
|
|
|
+ x = layer(x, input_pos, freqs_cis, mask)
|
|
|
+ x = self.norm(x)
|
|
|
+ return x
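+
+
+# Minimal inference-time sketch of the KV-cache path (illustrative only; the batch,
+# sequence length and feature width below assume the default ModelArgs values).
+def _example_transformer_kv_cache():
+ model = Transformer(ModelArgs())
+ model.eval()
+ model.setup_caches(max_batch_size=1, max_seq_length=16)
+ x = torch.randn(1, 1, 512) # one new frame of features
+ pos = torch.tensor([0]) # absolute position of that frame
+ with torch.no_grad():
+ y = model(x, input_pos=pos) # (1, 1, 512); the cache now holds position 0
+ return y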
|
|
|
+
|
|
|
+
|
|
|
+class TransformerBlock(nn.Module):
|
|
|
+ def __init__(self, config: ModelArgs) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.attention = Attention(config)
|
|
|
+ self.feed_forward = FeedForward(config)
|
|
|
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
|
|
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
|
|
+ self.attention_layer_scale = LayerScale(config.dim, inplace=True)
|
|
|
+ self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ x: Tensor,
|
|
|
+ input_pos: Tensor,
|
|
|
+ freqs_cis: Tensor,
|
|
|
+ mask: Tensor,
|
|
|
+ ) -> Tensor:
|
|
|
+ h = x + self.attention_layer_scale(
|
|
|
+ self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
|
|
|
+ )
|
|
|
+ out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
|
|
|
+ return out
|
|
|
+
|
|
|
+
|
|
|
+class Attention(nn.Module):
|
|
|
+ def __init__(self, config: ModelArgs):
|
|
|
+ super().__init__()
|
|
|
+ assert config.dim % config.n_head == 0
|
|
|
+
|
|
|
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
|
|
+ # key, query, value projections for all heads, but in a batch
|
|
|
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
|
|
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
|
|
|
+ self.kv_cache = None
|
|
|
+
|
|
|
+ self.n_head = config.n_head
|
|
|
+ self.head_dim = config.head_dim
|
|
|
+ self.n_local_heads = config.n_local_heads
|
|
|
+ self.dim = config.dim
|
|
|
+ self.attn_dropout_rate = config.attn_dropout_rate
|
|
|
+ self.pos_embed_type = config.pos_embed_type
|
|
|
+
|
|
|
+ # Add relative position embedding for conformer-style
|
|
|
+ if self.pos_embed_type == "conformer":
|
|
|
+ self.max_relative_position = config.max_relative_position
|
|
|
+ num_pos_embeddings = 2 * config.max_relative_position + 1
|
|
|
+ self.rel_pos_embeddings = nn.Parameter(
|
|
|
+ torch.zeros(num_pos_embeddings, self.head_dim)
|
|
|
+ )
|
|
|
+ nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
|
|
|
+
|
|
|
+ def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
|
|
|
+ # q: [B, H, S, D]
|
|
|
+ # Returns: [B, H, S, S]
|
|
|
+ positions = torch.arange(seqlen, device=q.device)
|
|
|
+ relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
|
|
|
+ relative_positions = torch.clamp(
|
|
|
+ relative_positions + self.max_relative_position,
|
|
|
+ 0,
|
|
|
+ 2 * self.max_relative_position,
|
|
|
+ )
|
|
|
+ rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
|
|
|
+
|
|
|
+ # Compute attention scores with relative position embeddings
|
|
|
+ q = q.transpose(1, 2) # [B, S, H, D]
|
|
|
+ rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
|
|
|
+ rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
|
|
|
+ return rel_logits
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ x: Tensor,
|
|
|
+ freqs_cis: Tensor,
|
|
|
+ mask: Tensor,
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
+ ) -> Tensor:
|
|
|
+ bsz, seqlen, _ = x.shape
|
|
|
+
|
|
|
+ kv_size = self.n_local_heads * self.head_dim
|
|
|
+ # query width is n_head * head_dim; k/v width may be smaller with grouped-query attention
+ q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
|
|
|
+ context_seqlen = seqlen
|
|
|
+
|
|
|
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
|
|
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
|
|
|
+
|
|
|
+ if self.pos_embed_type == "rope":
|
|
|
+ q = apply_rotary_emb(q, freqs_cis)
|
|
|
+ k = apply_rotary_emb(k, freqs_cis)
|
|
|
+
|
|
|
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
+
|
|
|
+ if self.kv_cache is not None:
|
|
|
+ k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
+
|
|
|
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
+
|
|
|
+ if self.pos_embed_type == "conformer":
|
|
|
+ # Compute attention scores
|
|
|
+ scale = 1.0 / math.sqrt(self.head_dim)
|
|
|
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
|
|
|
+
|
|
|
+ # Add relative position embeddings for conformer-style
|
|
|
+ rel_scores = self._compute_conformer_pos_scores(q, seqlen)
|
|
|
+ scores = scores + rel_scores
|
|
|
+
|
|
|
+ # Apply attention
|
|
|
+ if mask is not None:
|
|
|
+ scores = scores.masked_fill(~mask, float("-inf"))
|
|
|
+
|
|
|
+ attn = F.softmax(scores, dim=-1)
|
|
|
+ if self.attn_dropout_rate > 0 and self.training:
|
|
|
+ attn = F.dropout(attn, p=self.attn_dropout_rate)
|
|
|
+
|
|
|
+ y = torch.matmul(attn, v)
|
|
|
+ else:
|
|
|
+ y = F.scaled_dot_product_attention(
|
|
|
+ q,
|
|
|
+ k,
|
|
|
+ v,
|
|
|
+ dropout_p=self.attn_dropout_rate if self.training else 0.0,
|
|
|
+ attn_mask=mask,
|
|
|
+ )
|
|
|
+ # is_causal=True)
|
|
|
+ y = (
|
|
|
+ y.transpose(1, 2)
|
|
|
+ .contiguous()
|
|
|
+ .view(bsz, seqlen, self.head_dim * self.n_head)
|
|
|
+ )
|
|
|
+ y = self.wo(y)
|
|
|
+ return y
|
|
|
+
|
|
|
+
|
|
|
+class FeedForward(nn.Module):
|
|
|
+ def __init__(self, config: ModelArgs) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
|
|
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
|
|
|
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
|
|
|
+ self.dropout = nn.Dropout(config.dropout_rate)
|
|
|
+
|
|
|
+ def forward(self, x: Tensor) -> Tensor:
|
|
|
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
|
|
|
+
|
|
|
+
|
|
|
+class RMSNorm(nn.Module):
|
|
|
+ def __init__(self, dim: int, eps: float = 1e-5):
|
|
|
+ super().__init__()
|
|
|
+ self.eps = eps
|
|
|
+ self.weight = nn.Parameter(torch.ones(dim))
|
|
|
+
|
|
|
+ def _norm(self, x):
|
|
|
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
|
|
|
+
|
|
|
+ def forward(self, x: Tensor) -> Tensor:
|
|
|
+ output = self._norm(x.float()).type_as(x)
|
|
|
+ return output * self.weight
|
|
|
+
|
|
|
+
|
|
|
+class LayerScale(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ dim: int,
|
|
|
+ init_values: Union[float, Tensor] = 1e-2,
|
|
|
+ inplace: bool = False,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.inplace = inplace
|
|
|
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
|
|
+
|
|
|
+ def forward(self, x: Tensor) -> Tensor:
|
|
|
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
|
|
+
|
|
|
+
|
|
|
+class WindowLimitedTransformer(Transformer):
|
|
|
+ """
|
|
|
+ Transformer with window limited attention, causal.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: ModelArgs,
|
|
|
+ input_dim: int = 512,
|
|
|
+ window_size: Optional[int] = None,
|
|
|
+ causal: bool = True,
|
|
|
+ look_ahead_conv: nn.Module = None,
|
|
|
+ ):
|
|
|
+ super().__init__(config)
|
|
|
+ self.window_size = window_size
|
|
|
+ self.causal = causal
|
|
|
+ self.channels_first = config.channels_first
|
|
|
+ self.look_ahead_conv = (
|
|
|
+ look_ahead_conv if look_ahead_conv is not None else nn.Identity()
|
|
|
+ )
|
|
|
+ self.input_proj = (
|
|
|
+ nn.Linear(input_dim, config.dim)
|
|
|
+ if input_dim != config.dim
|
|
|
+ else nn.Identity()
|
|
|
+ )
|
|
|
+ self.output_proj = (
|
|
|
+ nn.Linear(config.dim, input_dim)
|
|
|
+ if input_dim != config.dim
|
|
|
+ else nn.Identity()
|
|
|
+ )
|
|
|
+
|
|
|
+ def make_window_limited_mask(
|
|
|
+ self,
|
|
|
+ max_length: int,
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
+ ) -> Tensor:
|
|
|
+ """
|
|
|
+ Make mask to form window limited attention.
|
|
|
+ """
|
|
|
+ if self.causal:
|
|
|
+ mask = torch.tril(torch.ones(max_length, max_length))
|
|
|
+ row_indices = torch.arange(max_length).view(-1, 1)
|
|
|
+ window_size = self.window_size or max_length
|
|
|
+ valid_range = (row_indices - window_size + 1).clamp(min=0)
|
|
|
+ column_indices = torch.arange(max_length)
|
|
|
+ mask = (column_indices >= valid_range) & mask.bool()
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
+ mask = mask.bool()[None, None]
|
|
|
+ return mask
|
|
|
+
|
|
|
+ def make_mask(
|
|
|
+ self,
|
|
|
+ max_length: int,
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
+ ) -> Tensor:
|
|
|
+ """
|
|
|
+ Make ordinary mask if window size is not specified.
|
|
|
+ """
|
|
|
+ if self.causal:
|
|
|
+ mask = torch.tril(torch.ones(max_length, max_length))
|
|
|
+ else:
|
|
|
+ mask = torch.ones(max_length, max_length)
|
|
|
+ mask = mask.bool()[None, None] # (1, 1, T, T)
+ if x_lens is not None:
+ # disallow attention to padded positions beyond each sequence length
+ key_mask = torch.arange(max_length, device=x_lens.device)[None, :] < x_lens[:, None]
+ mask = mask & key_mask[:, None, None, :]
|
|
|
+ return mask
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ x: Tensor,
|
|
|
+ x_lens: Optional[Tensor] = None,
|
|
|
+ ) -> Tensor:
|
|
|
+ if self.channels_first:
|
|
|
+ x = x.transpose(1, 2)
|
|
|
+ x = self.input_proj(x) # (B, T, D)
|
|
|
+ x = self.look_ahead_conv(x)
|
|
|
+ input_pos = torch.arange(x.shape[1], device=x.device)
|
|
|
+ # construct mask to form window limited attention
|
|
|
+ max_length = x.shape[1]
|
|
|
+ if self.window_size is not None:
|
|
|
+ mask = self.make_window_limited_mask(max_length, x_lens)
|
|
|
+ else:
|
|
|
+ mask = self.make_mask(max_length, x_lens)
|
|
|
+ mask = mask.to(x.device)
|
|
|
+ x = super().forward(x, input_pos, mask)
|
|
|
+ x = self.output_proj(x) # (B, T, D)
|
|
|
+ if self.channels_first:
|
|
|
+ x = x.transpose(1, 2)
|
|
|
+ return x
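+
+
+# Usage sketch for the window-limited transformer as the codec blocks drive it
+# (dimensions here are illustrative, not the shipped configuration): channels-first
+# features are projected in, attended with a causal banded mask, and projected back.
+def _example_window_limited_transformer():
+ cfg = ModelArgs(n_layer=2, n_head=4, dim=256, intermediate_size=768, head_dim=64)
+ wlt = WindowLimitedTransformer(cfg, input_dim=128, window_size=64, causal=True)
+ wlt.eval()
+ feats = torch.randn(2, 128, 100) # (B, C, T) because channels_first=True
+ with torch.no_grad():
+ out = wlt(feats) # (2, 128, 100)
+ return out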
|
|
|
+
|
|
|
+
|
|
|
+def precompute_freqs_cis(
|
|
|
+ seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
|
|
|
+) -> Tensor:
|
|
|
+ freqs = 1.0 / (
|
|
|
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
|
|
|
+ )
|
|
|
+ t = torch.arange(seq_len, device=freqs.device)
|
|
|
+ freqs = torch.outer(t, freqs)
|
|
|
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
|
|
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
|
|
+ return cache.to(dtype=dtype)
|
|
|
+
|
|
|
+
|
|
|
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
|
|
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
|
|
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
|
|
|
+ x_out2 = torch.stack(
|
|
|
+ [
|
|
|
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
|
|
|
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
|
|
|
+ ],
|
|
|
+ -1,
|
|
|
+ )
|
|
|
+
|
|
|
+ x_out2 = x_out2.flatten(3)
|
|
|
+ return x_out2.type_as(x)
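+
+
+# Shape sketch for the rotary-embedding helpers (illustrative sizes): the table from
+# precompute_freqs_cis holds (seq_len, head_dim // 2, 2) cos/sin pairs that
+# apply_rotary_emb uses to rotate consecutive feature pairs of each head.
+def _example_rotary_embedding():
+ freqs = precompute_freqs_cis(seq_len=16, n_elem=64, dtype=torch.float32)
+ q = torch.randn(2, 16, 8, 64) # (B, S, n_head, head_dim)
+ return apply_rotary_emb(q, freqs) # same shape, with rotated pairs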
|
|
|
+
|
|
|
+
|
|
|
+def init_weights(m):
|
|
|
+ if isinstance(m, nn.Conv1d):
|
|
|
+ nn.init.trunc_normal_(m.weight, std=0.02)
|
|
|
+ nn.init.constant_(m.bias, 0)
|
|
|
+
|
|
|
+
|
|
|
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
|
|
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
|
|
|
+ padding_left, padding_right = paddings
|
|
|
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
|
|
+ assert (padding_left + padding_right) <= x.shape[-1]
|
|
|
+ end = x.shape[-1] - padding_right
|
|
|
+ return x[..., padding_left:end]
|
|
|
+
|
|
|
+
|
|
|
+def get_extra_padding_for_conv1d(
|
|
|
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
|
|
+) -> int:
|
|
|
+ """See `pad_for_conv1d`."""
|
|
|
+ length = x.shape[-1]
|
|
|
+ n_frames = (length - kernel_size + padding_total) / stride + 1
|
|
|
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
|
|
+ return ideal_length - length
|
|
|
+
|
|
|
+
|
|
|
+def pad1d(
|
|
|
+ x: torch.Tensor,
|
|
|
+ paddings: tp.Tuple[int, int],
|
|
|
+ mode: str = "constant",
|
|
|
+ value: float = 0.0,
|
|
|
+):
|
|
|
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
|
|
+ If this is the case, we insert extra 0 padding to the right
|
|
|
+ before the reflection happen.
|
|
|
+ """
|
|
|
+ length = x.shape[-1]
|
|
|
+ padding_left, padding_right = paddings
|
|
|
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
|
|
+ if mode == "reflect":
|
|
|
+ max_pad = max(padding_left, padding_right)
|
|
|
+ extra_pad = 0
|
|
|
+ if length <= max_pad:
|
|
|
+ extra_pad = max_pad - length + 1
|
|
|
+ x = F.pad(x, (0, extra_pad))
|
|
|
+ padded = F.pad(x, paddings, mode, value)
|
|
|
+ end = padded.shape[-1] - extra_pad
|
|
|
+ return padded[..., :end]
|
|
|
+ else:
|
|
|
+ return F.pad(x, paddings, mode, value)
|
|
|
+
|
|
|
+
|
|
|
+class CausalConvNet(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ in_channels,
|
|
|
+ out_channels,
|
|
|
+ kernel_size,
|
|
|
+ dilation=1,
|
|
|
+ stride=1,
|
|
|
+ groups=1,
|
|
|
+ padding=None, # accepted for interface parity; causal left-padding is computed in forward()
|
|
|
+ ):
|
|
|
+ super(CausalConvNet, self).__init__()
|
|
|
+ self.conv = nn.Conv1d(
|
|
|
+ in_channels,
|
|
|
+ out_channels,
|
|
|
+ kernel_size,
|
|
|
+ stride=stride,
|
|
|
+ dilation=dilation,
|
|
|
+ groups=groups,
|
|
|
+ )
|
|
|
+ self.stride = stride
|
|
|
+ self.kernel_size = (kernel_size - 1) * dilation + 1
|
|
|
+ self.dilation = dilation
|
|
|
+ self.padding = self.kernel_size - self.stride
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ pad = self.padding
|
|
|
+ extra_padding = get_extra_padding_for_conv1d(
|
|
|
+ x, self.kernel_size, self.stride, pad
|
|
|
+ )
|
|
|
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
|
|
|
+ return self.conv(x).contiguous()
|
|
|
+
|
|
|
+ def weight_norm(self, name="weight", dim=0):
|
|
|
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
|
+ return self
|
|
|
+
|
|
|
+ def remove_weight_norm(self):
|
|
|
+ self.conv = remove_parametrizations(self.conv, "weight")
|
|
|
+ return self
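+
+
+# Padding sketch for the causal convolution (illustrative sizes): the input is
+# left-padded by effective_kernel_size - stride, so each output frame depends only
+# on current and past samples, and a stride-1 layer preserves the length.
+def _example_causal_conv():
+ conv = CausalConvNet(1, 8, kernel_size=7)
+ y = conv(torch.randn(1, 1, 100)) # -> (1, 8, 100)
+ return y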
|
|
|
+
|
|
|
+
|
|
|
+class CausalTransConvNet(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
|
|
|
+ ):
|
|
|
+ super(CausalTransConvNet, self).__init__()
|
|
|
+ self.conv = nn.ConvTranspose1d(
|
|
|
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
|
|
|
+ )
|
|
|
+ self.stride = stride
|
|
|
+ self.kernel_size = kernel_size
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ x = self.conv(x)
|
|
|
+ pad = self.kernel_size - self.stride
|
|
|
+ # causal: trim all of the transposed-conv overlap from the right
+ padding_right = pad
+ padding_left = 0
|
|
|
+ x = unpad1d(x, (padding_left, padding_right))
|
|
|
+ return x.contiguous()
|
|
|
+
|
|
|
+ def weight_norm(self, name="weight", dim=0):
|
|
|
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
|
|
|
+ return self
|
|
|
+
|
|
|
+ def remove_weight_norm(self):
|
|
|
+ self.conv = remove_parametrizations(self.conv, "weight")
|
|
|
+ return self
|
|
|
+
|
|
|
+
|
|
|
+def CausalWNConv1d(*args, **kwargs):
|
|
|
+ return CausalConvNet(*args, **kwargs).weight_norm()
|
|
|
+
|
|
|
+
|
|
|
+def CausalWNConvTranspose1d(*args, **kwargs):
|
|
|
+ return CausalTransConvNet(*args, **kwargs).weight_norm()
|
|
|
+
|
|
|
+
|
|
|
+class ResidualUnit(nn.Module):
|
|
|
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
|
|
|
+ super().__init__()
|
|
|
+ conv_class = CausalWNConv1d if causal else WNConv1d
|
|
|
+ pad = ((7 - 1) * dilation) // 2
|
|
|
+ self.block = nn.Sequential(
|
|
|
+ Snake1d(dim),
|
|
|
+ conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
|
|
+ Snake1d(dim),
|
|
|
+ conv_class(dim, dim, kernel_size=1),
|
|
|
+ )
|
|
|
+ self.causal = causal
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ y = self.block(x)
|
|
|
+ pad = x.shape[-1] - y.shape[-1]
|
|
|
+ if pad > 0:
|
|
|
+ if self.causal:
|
|
|
+ x = x[..., :-pad]
|
|
|
+ else:
|
|
|
+ x = x[..., pad // 2 : -pad // 2]
|
|
|
+ return x + y
|
|
|
+
|
|
|
+
|
|
|
+class EncoderBlock(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ dim: int = 16,
|
|
|
+ stride: int = 1,
|
|
|
+ causal: bool = False,
|
|
|
+ n_t_layer: int = 0,
|
|
|
+ transformer_general_config=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ conv_class = CausalWNConv1d if causal else WNConv1d
|
|
|
+ transformer_module = (
|
|
|
+ nn.Identity()
|
|
|
+ if n_t_layer == 0
|
|
|
+ else (
|
|
|
+ WindowLimitedTransformer(
|
|
|
+ causal=causal,
|
|
|
+ input_dim=dim,
|
|
|
+ window_size=512,
|
|
|
+ config=transformer_general_config(
|
|
|
+ n_layer=n_t_layer,
|
|
|
+ n_head=dim // 64,
|
|
|
+ dim=dim,
|
|
|
+ intermediate_size=dim * 3,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ )
|
|
|
+ self.block = nn.Sequential(
|
|
|
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
|
|
|
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
|
|
|
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
|
|
|
+ Snake1d(dim // 2),
|
|
|
+ conv_class(
|
|
|
+ dim // 2,
|
|
|
+ dim,
|
|
|
+ kernel_size=2 * stride,
|
|
|
+ stride=stride,
|
|
|
+ padding=math.ceil(stride / 2),
|
|
|
+ ),
|
|
|
+ transformer_module,
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ return self.block(x)
|
|
|
+
|
|
|
+
|
|
|
+class Encoder(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ d_model: int = 64,
|
|
|
+ strides: list = [2, 4, 8, 8],
|
|
|
+ d_latent: int = 64,
|
|
|
+ n_transformer_layers: list = [0, 0, 4, 4],
|
|
|
+ transformer_general_config: ModelArgs = None,
|
|
|
+ causal: bool = False,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ conv_class = CausalWNConv1d if causal else WNConv1d
|
|
|
+ # Create first convolution
|
|
|
+ self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
|
|
|
+
|
|
|
+ # Create EncoderBlocks that double channels as they downsample by `stride`
|
|
|
+ for stride, n_t_layer in zip(strides, n_transformer_layers):
|
|
|
+ d_model *= 2
|
|
|
+ self.block += [
|
|
|
+ EncoderBlock(
|
|
|
+ d_model,
|
|
|
+ stride=stride,
|
|
|
+ causal=causal,
|
|
|
+ n_t_layer=n_t_layer,
|
|
|
+ transformer_general_config=transformer_general_config,
|
|
|
+ )
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Create last convolution
|
|
|
+ self.block += [
|
|
|
+ Snake1d(d_model),
|
|
|
+ conv_class(d_model, d_latent, kernel_size=3, padding=1),
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Wrap block into nn.Sequential
|
|
|
+ self.block = nn.Sequential(*self.block)
|
|
|
+ self.enc_dim = d_model
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ return self.block(x)
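+
+
+# Downsampling sketch for the encoder (illustrative sizes, transformer layers
+# disabled): strides [2, 4, 8, 8] compound to a hop of 512 samples per latent frame.
+def _example_encoder_hop():
+ enc = Encoder(
+ d_model=64,
+ strides=[2, 4, 8, 8],
+ d_latent=1024,
+ n_transformer_layers=[0, 0, 0, 0],
+ causal=True,
+ )
+ enc.eval()
+ with torch.no_grad():
+ z = enc(torch.randn(1, 1, 5120)) # -> (1, 1024, 10), i.e. 5120 / 512 frames
+ return z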
|
|
|
+
|
|
|
+
|
|
|
+class DecoderBlock(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ input_dim: int = 16,
|
|
|
+ output_dim: int = 8,
|
|
|
+ stride: int = 1,
|
|
|
+ causal: bool = False,
|
|
|
+ n_t_layer: int = 0,
|
|
|
+ transformer_general_config=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
|
|
|
+ transformer_module = (
|
|
|
+ nn.Identity()
|
|
|
+ if n_t_layer == 0
|
|
|
+ else (
|
|
|
+ WindowLimitedTransformer(
|
|
|
+ causal=causal,
|
|
|
+ input_dim=input_dim,
|
|
|
+ window_size=None,
|
|
|
+ config=transformer_general_config(
|
|
|
+ n_layer=n_t_layer,
|
|
|
+ n_head=input_dim // 64,
|
|
|
+ dim=input_dim,
|
|
|
+ intermediate_size=input_dim * 3,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ )
|
|
|
+ self.block = nn.Sequential(
|
|
|
+ # transformer_module,
|
|
|
+ Snake1d(input_dim),
|
|
|
+ conv_trans_class(
|
|
|
+ input_dim,
|
|
|
+ output_dim,
|
|
|
+ kernel_size=2 * stride,
|
|
|
+ stride=stride,
|
|
|
+ padding=math.ceil(stride / 2),
|
|
|
+ ),
|
|
|
+ ResidualUnit(output_dim, dilation=1, causal=causal),
|
|
|
+ ResidualUnit(output_dim, dilation=3, causal=causal),
|
|
|
+ ResidualUnit(output_dim, dilation=9, causal=causal),
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ return self.block(x)
|
|
|
+
|
|
|
+
|
|
|
+class Decoder(nn.Module):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ input_channel,
|
|
|
+ channels,
|
|
|
+ rates,
|
|
|
+ d_out: int = 1,
|
|
|
+ causal: bool = False,
|
|
|
+ n_transformer_layers: list = [0, 0, 0, 0],
|
|
|
+ transformer_general_config=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ conv_class = CausalWNConv1d if causal else WNConv1d
|
|
|
+ # Add first conv layer
|
|
|
+ layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
|
|
|
+
|
|
|
+ # Add upsampling + MRF blocks
|
|
|
+ for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
|
|
|
+ input_dim = channels // 2**i
|
|
|
+ output_dim = channels // 2 ** (i + 1)
|
|
|
+ layers += [
|
|
|
+ DecoderBlock(
|
|
|
+ input_dim,
|
|
|
+ output_dim,
|
|
|
+ stride,
|
|
|
+ causal=causal,
|
|
|
+ n_t_layer=n_t_layer,
|
|
|
+ transformer_general_config=transformer_general_config,
|
|
|
+ )
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Add final conv layer
|
|
|
+ layers += [
|
|
|
+ Snake1d(output_dim),
|
|
|
+ conv_class(output_dim, d_out, kernel_size=7, padding=3),
|
|
|
+ nn.Tanh(),
|
|
|
+ ]
|
|
|
+
|
|
|
+ self.model = nn.Sequential(*layers)
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ return self.model(x)
|
|
|
+
|
|
|
+
|
|
|
+class DAC(BaseModel, CodecMixin):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ encoder_dim: int = 64,
|
|
|
+ encoder_rates: List[int] = [2, 4, 8, 8],
|
|
|
+ latent_dim: int = None,
|
|
|
+ decoder_dim: int = 1536,
|
|
|
+ decoder_rates: List[int] = [8, 8, 4, 2],
|
|
|
+ quantizer: torch.nn.Module = None,
|
|
|
+ sample_rate: int = 44100,
|
|
|
+ causal: bool = True,
|
|
|
+ encoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
+ decoder_transformer_layers: List[int] = [0, 0, 0, 0],
|
|
|
+ transformer_general_config=None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.encoder_dim = encoder_dim
|
|
|
+ self.encoder_rates = encoder_rates
|
|
|
+ self.decoder_dim = decoder_dim
|
|
|
+ self.decoder_rates = decoder_rates
|
|
|
+ self.sample_rate = sample_rate
|
|
|
+
|
|
|
+ if latent_dim is None:
|
|
|
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
|
|
+
|
|
|
+ self.latent_dim = latent_dim
|
|
|
+
|
|
|
+ self.hop_length = np.prod(encoder_rates)
|
|
|
+ self.encoder = Encoder(
|
|
|
+ encoder_dim,
|
|
|
+ encoder_rates,
|
|
|
+ latent_dim,
|
|
|
+ causal=causal,
|
|
|
+ n_transformer_layers=encoder_transformer_layers,
|
|
|
+ transformer_general_config=transformer_general_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.quantizer = quantizer
|
|
|
+
|
|
|
+ self.decoder = Decoder(
|
|
|
+ latent_dim,
|
|
|
+ decoder_dim,
|
|
|
+ decoder_rates,
|
|
|
+ causal=causal,
|
|
|
+ n_transformer_layers=decoder_transformer_layers,
|
|
|
+ transformer_general_config=transformer_general_config,
|
|
|
+ )
|
|
|
+ self.sample_rate = sample_rate
|
|
|
+ self.apply(init_weights)
|
|
|
+
|
|
|
+ self.delay = self.get_delay()
|
|
|
+
|
|
|
+ self.frame_length = self.hop_length * 4
|
|
|
+
|
|
|
+ def preprocess(self, audio_data, sample_rate):
|
|
|
+ if sample_rate is None:
|
|
|
+ sample_rate = self.sample_rate
|
|
|
+ assert sample_rate == self.sample_rate
|
|
|
+
|
|
|
+ length = audio_data.shape[-1]
|
|
|
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
|
|
|
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
|
|
+
|
|
|
+ return audio_data
|
|
|
+
|
|
|
+ def encode(
|
|
|
+ self,
|
|
|
+ audio_data: torch.Tensor,
|
|
|
+ audio_lengths: torch.Tensor = None,
|
|
|
+ n_quantizers: int = None,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
+ """Encode given audio data and return quantized latent codes
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ audio_data : Tensor[B x T]
|
|
|
+ Audio data to encode
|
|
|
+ n_quantizers : int, optional
|
|
|
+ Number of quantizers to use, by default None
|
|
|
+ If None, all quantizers are used.
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ indices : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ indices_lens : Tensor[B]
+ Number of valid code frames for each item in the batch
|
|
|
+ """
|
|
|
+ # pad to multiple of self.frame_length
|
|
|
+ if audio_data.ndim == 2:
|
|
|
+ audio_data = audio_data.unsqueeze(1)
|
|
|
+ # print(audio_data.shape)
|
|
|
+ length = audio_data.shape[-1]
|
|
|
+ right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
|
|
|
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
|
|
+ if audio_lengths is None:
|
|
|
+ # one padded length per batch item
+ audio_lengths = torch.LongTensor([length + right_pad] * audio_data.shape[0]).to(audio_data.device)
|
|
|
+
|
|
|
+ z = self.encoder(audio_data)
|
|
|
+ vq_results = self.quantizer(z, n_quantizers, **kwargs)
|
|
|
+ indices = vq_results.codes
|
|
|
+ indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
|
|
|
+ return indices, indices_lens
|
|
|
+
|
|
|
+ def decode(self, indices: torch.Tensor, feature_lengths):
|
|
|
+ z = self.quantizer.decode(indices)
|
|
|
+ audio_lengths = feature_lengths * self.frame_length
|
|
|
+ return self.decoder(z), audio_lengths
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ audio_data: torch.Tensor,
|
|
|
+ template: torch.Tensor = None,
|
|
|
+ mask: torch.Tensor = None,
|
|
|
+ sample_rate: int = None,
|
|
|
+ n_quantizers: int = None,
|
|
|
+ **kwargs,
|
|
|
+ ):
|
|
|
+ """Model forward pass
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ audio_data : Tensor[B x 1 x T]
|
|
|
+ Audio data to encode
|
|
|
+ sample_rate : int, optional
|
|
|
+ Sample rate of audio data in Hz, by default None
|
|
|
+ If None, defaults to `self.sample_rate`
|
|
|
+ n_quantizers : int, optional
|
|
|
+ Number of quantizers to use, by default None.
|
|
|
+ If None, all quantizers are used.
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ audio : Tensor[B x 1 x length]
+ Decoded audio data, trimmed to the original input length
+ vq_results : tuple
+ (indices, indices_lens) as returned by `encode`
|
|
|
+ """
|
|
|
+ length = audio_data.shape[-1]
|
|
|
+ audio_data = self.preprocess(audio_data, sample_rate)
|
|
|
+ indices, indices_lens = self.encode(audio_data, n_quantizers=n_quantizers, **kwargs)
+ x, _ = self.decode(indices, indices_lens)
+ return x[..., :length], (indices, indices_lens)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+
|
|
|
+ def filter_state_dict_shapes(params, model):
|
|
|
+ model_state_dict = model.state_dict()
|
|
|
+ filtered_state_dict = {
|
|
|
+ k: v
|
|
|
+ for k, v in params.items()
|
|
|
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
|
|
|
+ }
|
|
|
+ skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
|
|
|
+ if skipped_keys:
|
|
|
+ print(
|
|
|
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
|
|
|
+ )
|
|
|
+ return filtered_state_dict, skipped_keys
|
|
|
+
|
|
|
+ model = hydra.utils.instantiate(
|
|
|
+ OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
|
|
|
+ )
|
|
|
+ sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
|
|
|
+ filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
|
|
|
+ print(f"Skipped keys: {skipped_keys}")
|
|
|
+ model.load_state_dict(filtered_sd, strict=False)
|
|
|
+ model.eval()
|
|
|
+
|
|
|
+ src_audio_path = "./test.wav"
|
|
|
+ wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
|
|
|
+ if len(wave_np.shape) == 1:
|
|
|
+ wave_np = wave_np[None, :]
|
|
|
+ wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ # encode returns (indices, indices_lens)
|
|
|
+ indices, indices_lens = model.encode(wave_tensor)
|
|
|
+ print(f"Indices shape: {indices.shape}")
|
|
|
+ print(f"Indices lengths: {indices_lens}")
|
|
|
+
|
|
|
+ # decode takes both indices and feature_lengths
|
|
|
+ fake_audio, audio_lengths = model.decode(indices, indices_lens)
|
|
|
+ print(f"Decoded audio shape: {fake_audio.shape}")
|
|
|
+ print(f"Audio lengths: {audio_lengths}")
|
|
|
+
|
|
|
+ # save the reconstructed audio
|
|
|
+ sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
|