modded_dac.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. import math
  2. import typing as tp
  3. from dataclasses import dataclass
  4. from typing import List, Optional, Union
  5. import hydra
  6. import librosa
  7. import numpy as np
  8. import soundfile as sf
  9. import torch
  10. from audiotools import AudioSignal
  11. from audiotools.ml import BaseModel
  12. from dac.model.base import CodecMixin
  13. from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
  14. from omegaconf import OmegaConf
  15. from torch import Tensor, nn
  16. from torch.nn import functional as F
  17. from torch.nn.utils.parametrizations import weight_norm
  18. from torch.nn.utils.parametrize import remove_parametrizations
@dataclass
class VQResult:
    """Outputs of one vector-quantization pass."""

    z: torch.Tensor  # quantized continuous representation
    codes: torch.Tensor  # discrete codebook indices
    latents: torch.Tensor  # projected latents before quantization
    codebook_loss: torch.Tensor  # loss that updates the codebook entries
    commitment_loss: torch.Tensor  # loss pulling encoder outputs toward codes
    semantic_distill_z: torch.Tensor | None = None  # optional distillation target
  27. def find_multiple(n: int, k: int) -> int:
  28. if n % k == 0:
  29. return n
  30. return n + k - (n % k)
@dataclass
class ModelArgs:
    """Configuration for the codec's internal Transformer stack."""

    block_size: int = 2048  # maximum sequence length (sizes RoPE table and causal mask)
    n_layer: int = 8
    n_head: int = 8
    dim: int = 512
    intermediate_size: int = 1536  # FFN hidden size; derived in __post_init__ if None
    n_local_heads: int = -1  # KV heads for grouped-query attention; -1 means n_head
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    dropout_rate: float = 0.1
    attn_dropout_rate: float = 0.1
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # can be "rope" or "conformer"
    max_relative_position: int = 128  # for conformer-style relative position embedding

    def __post_init__(self):
        # -1 is a sentinel meaning "one KV head per query head".
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        # Callers may pass intermediate_size=None to derive a SwiGLU-style
        # hidden size: 2/3 of 4*dim, rounded up to a multiple of 256.
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in [
            "rope",
            "conformer",
        ], "pos_embed_type must be either 'rope' or 'conformer'"
  58. class KVCache(nn.Module):
  59. def __init__(
  60. self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
  61. ):
  62. super().__init__()
  63. cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
  64. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  65. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  66. def update(self, input_pos, k_val, v_val):
  67. # input_pos: [S], k_val: [B, H, S, D]
  68. assert input_pos.shape[0] == k_val.shape[2]
  69. k_out = self.k_cache
  70. v_out = self.v_cache
  71. k_out[:, :, input_pos] = k_val
  72. v_out[:, :, input_pos] = v_val
  73. return (
  74. k_out[:, :, : input_pos.max() + 1, :],
  75. v_out[:, :, : input_pos.max() + 1, :],
  76. )
  77. def clear_cache(self, prompt_len):
  78. self.k_cache[:, :, prompt_len:, :].fill_(0)
  79. self.v_cache[:, :, prompt_len:, :].fill_(0)
class Transformer(nn.Module):
    """Stack of pre-norm transformer blocks with optional RoPE and KV caching."""

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Only compute RoPE frequencies if using RoPE
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                self.config.block_size, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis)
        else:
            self.register_buffer("freqs_cis", None)
        # Full lower-triangular mask; forward() slices it per call.
        causal_mask = torch.tril(
            torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", causal_mask)
        self.max_batch_size = -1  # set by setup_caches()
        self.max_seq_length = -1  # set by setup_caches()
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """
        This method will only be called during inference when using KV cache.
        """
        head_dim = self.config.dim // self.config.n_head
        # Round the cache length up to a multiple of 8.
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        # Match cache dtype/device to the model's parameters.
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)
        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the layer stack; builds a causal mask from input_pos when none is given."""
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None
        if mask is None:  # in case of non-causal model
            if not self.training and self.use_kv_cache:
                # Incremental decoding: keys span positions 0..max(input_pos).
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                # Full-sequence pass: restrict keys to the same positions.
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x
  146. class TransformerBlock(nn.Module):
  147. def __init__(self, config: ModelArgs) -> None:
  148. super().__init__()
  149. self.attention = Attention(config)
  150. self.feed_forward = FeedForward(config)
  151. self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
  152. self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
  153. self.attention_layer_scale = LayerScale(config.dim, inplace=True)
  154. self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
  155. def forward(
  156. self,
  157. x: Tensor,
  158. input_pos: Tensor,
  159. freqs_cis: Tensor,
  160. mask: Tensor,
  161. ) -> Tensor:
  162. h = x + self.attention_layer_scale(
  163. self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  164. )
  165. out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
  166. return out
  167. class Attention(nn.Module):
  168. def __init__(self, config: ModelArgs):
  169. super().__init__()
  170. assert config.dim % config.n_head == 0
  171. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  172. # key, query, value projections for all heads, but in a batch
  173. self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
  174. self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
  175. self.kv_cache = None
  176. self.n_head = config.n_head
  177. self.head_dim = config.head_dim
  178. self.n_local_heads = config.n_local_heads
  179. self.dim = config.dim
  180. self.attn_dropout_rate = config.attn_dropout_rate
  181. self.pos_embed_type = config.pos_embed_type
  182. # Add relative position embedding for conformer-style
  183. if self.pos_embed_type == "conformer":
  184. self.max_relative_position = config.max_relative_position
  185. num_pos_embeddings = 2 * config.max_relative_position + 1
  186. self.rel_pos_embeddings = nn.Parameter(
  187. torch.zeros(num_pos_embeddings, self.head_dim)
  188. )
  189. nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
  190. def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
  191. # q: [B, H, S, D]
  192. # Returns: [B, H, S, S]
  193. positions = torch.arange(seqlen, device=q.device)
  194. relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
  195. relative_positions = torch.clamp(
  196. relative_positions + self.max_relative_position,
  197. 0,
  198. 2 * self.max_relative_position,
  199. )
  200. rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
  201. # Compute attention scores with relative position embeddings
  202. q = q.transpose(1, 2) # [B, S, H, D]
  203. rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
  204. rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
  205. return rel_logits
  206. def forward(
  207. self,
  208. x: Tensor,
  209. freqs_cis: Tensor,
  210. mask: Tensor,
  211. input_pos: Optional[Tensor] = None,
  212. ) -> Tensor:
  213. bsz, seqlen, _ = x.shape
  214. kv_size = self.n_local_heads * self.head_dim
  215. q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
  216. context_seqlen = seqlen
  217. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  218. k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  219. v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  220. if self.pos_embed_type == "rope":
  221. q = apply_rotary_emb(q, freqs_cis)
  222. k = apply_rotary_emb(k, freqs_cis)
  223. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  224. if self.kv_cache is not None:
  225. k, v = self.kv_cache.update(input_pos, k, v)
  226. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  227. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  228. if self.pos_embed_type == "conformer":
  229. # Compute attention scores
  230. scale = 1.0 / math.sqrt(self.head_dim)
  231. scores = torch.matmul(q, k.transpose(-2, -1)) * scale
  232. # Add relative position embeddings for conformer-style
  233. rel_scores = self._compute_conformer_pos_scores(q, seqlen)
  234. scores = scores + rel_scores
  235. # Apply attention
  236. if mask is not None:
  237. scores = scores.masked_fill(~mask, float("-inf"))
  238. attn = F.softmax(scores, dim=-1)
  239. if self.attn_dropout_rate > 0 and self.training:
  240. attn = F.dropout(attn, p=self.attn_dropout_rate)
  241. y = torch.matmul(attn, v)
  242. else:
  243. y = F.scaled_dot_product_attention(
  244. q,
  245. k,
  246. v,
  247. dropout_p=self.attn_dropout_rate if self.training else 0.0,
  248. attn_mask=mask,
  249. )
  250. # is_causal=True)
  251. y = (
  252. y.transpose(1, 2)
  253. .contiguous()
  254. .view(bsz, seqlen, self.head_dim * self.n_head)
  255. )
  256. y = self.wo(y)
  257. return y
  258. class FeedForward(nn.Module):
  259. def __init__(self, config: ModelArgs) -> None:
  260. super().__init__()
  261. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  262. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  263. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  264. self.dropout = nn.Dropout(config.dropout_rate)
  265. def forward(self, x: Tensor) -> Tensor:
  266. return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
  267. class RMSNorm(nn.Module):
  268. def __init__(self, dim: int, eps: float = 1e-5):
  269. super().__init__()
  270. self.eps = eps
  271. self.weight = nn.Parameter(torch.ones(dim))
  272. def _norm(self, x):
  273. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  274. def forward(self, x: Tensor) -> Tensor:
  275. output = self._norm(x.float()).type_as(x)
  276. return output * self.weight
  277. class LayerScale(nn.Module):
  278. def __init__(
  279. self,
  280. dim: int,
  281. init_values: Union[float, Tensor] = 1e-2,
  282. inplace: bool = False,
  283. ) -> None:
  284. super().__init__()
  285. self.inplace = inplace
  286. self.gamma = nn.Parameter(init_values * torch.ones(dim))
  287. def forward(self, x: Tensor) -> Tensor:
  288. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  289. class WindowLimitedTransformer(Transformer):
  290. """
  291. Transformer with window limited attention, causal.
  292. """
  293. def __init__(
  294. self,
  295. config: ModelArgs,
  296. input_dim: int = 512,
  297. window_size: Optional[int] = None,
  298. causal: bool = True,
  299. look_ahead_conv: nn.Module = None,
  300. ):
  301. super().__init__(config)
  302. self.window_size = window_size
  303. self.causal = causal
  304. self.channels_first = config.channels_first
  305. self.look_ahead_conv = (
  306. look_ahead_conv if look_ahead_conv is not None else nn.Identity()
  307. )
  308. self.input_proj = (
  309. nn.Linear(input_dim, config.dim)
  310. if input_dim != config.dim
  311. else nn.Identity()
  312. )
  313. self.output_proj = (
  314. nn.Linear(config.dim, input_dim)
  315. if input_dim != config.dim
  316. else nn.Identity()
  317. )
  318. def make_window_limited_mask(
  319. self,
  320. max_length: int,
  321. x_lens: Optional[Tensor] = None,
  322. ) -> Tensor:
  323. """
  324. Make mask to form window limited attention.
  325. """
  326. if self.causal:
  327. mask = torch.tril(torch.ones(max_length, max_length))
  328. row_indices = torch.arange(max_length).view(-1, 1)
  329. window_size = self.window_size or max_length
  330. valid_range = (row_indices - window_size + 1).clamp(min=0)
  331. column_indices = torch.arange(max_length)
  332. mask = (column_indices >= valid_range) & mask.bool()
  333. else:
  334. raise NotImplementedError
  335. mask = mask.bool()[None, None]
  336. return mask
  337. def make_mask(
  338. self,
  339. max_length: int,
  340. x_lens: Optional[Tensor] = None,
  341. ) -> Tensor:
  342. """
  343. Make ordinary mask if window size is not specified.
  344. """
  345. if self.causal:
  346. mask = torch.tril(torch.ones(max_length, max_length))
  347. else:
  348. mask = torch.ones(max_length, max_length)
  349. mask = mask.bool()[None, None]
  350. for i, x_len in enumerate(x_lens):
  351. mask[:x_len, i] = 0
  352. mask = mask.bool()[None, None]
  353. return mask
  354. def forward(
  355. self,
  356. x: Tensor,
  357. x_lens: Optional[Tensor] = None,
  358. ) -> Tensor:
  359. if self.channels_first:
  360. x = x.transpose(1, 2)
  361. x = self.input_proj(x) # (B, T, D)
  362. x = self.look_ahead_conv(x)
  363. input_pos = torch.arange(x.shape[1], device=x.device)
  364. # construct mask to form window limited attention
  365. max_length = x.shape[1]
  366. if self.window_size is not None:
  367. mask = self.make_window_limited_mask(max_length, x_lens)
  368. else:
  369. mask = self.make_mask(max_length, x_lens)
  370. mask = mask.to(x.device)
  371. x = super().forward(x, input_pos, mask)
  372. x = self.output_proj(x) # (B, T, D)
  373. if self.channels_first:
  374. x = x.transpose(1, 2)
  375. return x
  376. def precompute_freqs_cis(
  377. seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
  378. ) -> Tensor:
  379. freqs = 1.0 / (
  380. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  381. )
  382. t = torch.arange(seq_len, device=freqs.device)
  383. freqs = torch.outer(t, freqs)
  384. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  385. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  386. return cache.to(dtype=dtype)
  387. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  388. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  389. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  390. x_out2 = torch.stack(
  391. [
  392. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  393. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  394. ],
  395. -1,
  396. )
  397. x_out2 = x_out2.flatten(3)
  398. return x_out2.type_as(x)
  399. def init_weights(m):
  400. if isinstance(m, nn.Conv1d):
  401. nn.init.trunc_normal_(m.weight, std=0.02)
  402. nn.init.constant_(m.bias, 0)
  403. def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  404. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  405. padding_left, padding_right = paddings
  406. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  407. assert (padding_left + padding_right) <= x.shape[-1]
  408. end = x.shape[-1] - padding_right
  409. return x[..., padding_left:end]
  410. def get_extra_padding_for_conv1d(
  411. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  412. ) -> int:
  413. """See `pad_for_conv1d`."""
  414. length = x.shape[-1]
  415. n_frames = (length - kernel_size + padding_total) / stride + 1
  416. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  417. return ideal_length - length
  418. def pad1d(
  419. x: torch.Tensor,
  420. paddings: tp.Tuple[int, int],
  421. mode: str = "zeros",
  422. value: float = 0.0,
  423. ):
  424. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  425. If this is the case, we insert extra 0 padding to the right
  426. before the reflection happen.
  427. """
  428. length = x.shape[-1]
  429. padding_left, padding_right = paddings
  430. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  431. if mode == "reflect":
  432. max_pad = max(padding_left, padding_right)
  433. extra_pad = 0
  434. if length <= max_pad:
  435. extra_pad = max_pad - length + 1
  436. x = F.pad(x, (0, extra_pad))
  437. padded = F.pad(x, paddings, mode, value)
  438. end = padded.shape[-1] - extra_pad
  439. return padded[..., :end]
  440. else:
  441. return F.pad(x, paddings, mode, value)
  442. class CausalConvNet(nn.Module):
  443. def __init__(
  444. self,
  445. in_channels,
  446. out_channels,
  447. kernel_size,
  448. dilation=1,
  449. stride=1,
  450. groups=1,
  451. padding=None,
  452. ):
  453. super(CausalConvNet, self).__init__()
  454. self.conv = nn.Conv1d(
  455. in_channels,
  456. out_channels,
  457. kernel_size,
  458. stride=stride,
  459. dilation=dilation,
  460. groups=groups,
  461. )
  462. self.stride = stride
  463. self.kernel_size = (kernel_size - 1) * dilation + 1
  464. self.dilation = dilation
  465. self.padding = self.kernel_size - self.stride
  466. def forward(self, x):
  467. pad = self.padding
  468. extra_padding = get_extra_padding_for_conv1d(
  469. x, self.kernel_size, self.stride, pad
  470. )
  471. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  472. return self.conv(x).contiguous()
  473. def weight_norm(self, name="weight", dim=0):
  474. self.conv = weight_norm(self.conv, name=name, dim=dim)
  475. return self
  476. def remove_weight_norm(self):
  477. self.conv = remove_parametrizations(self.conv)
  478. return self
  479. class CausalTransConvNet(nn.Module):
  480. def __init__(
  481. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
  482. ):
  483. super(CausalTransConvNet, self).__init__()
  484. self.conv = nn.ConvTranspose1d(
  485. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  486. )
  487. self.stride = stride
  488. self.kernel_size = kernel_size
  489. def forward(self, x):
  490. x = self.conv(x)
  491. pad = self.kernel_size - self.stride
  492. padding_right = math.ceil(pad)
  493. padding_left = pad - padding_right
  494. x = unpad1d(x, (padding_left, padding_right))
  495. return x.contiguous()
  496. def weight_norm(self, name="weight", dim=0):
  497. self.conv = weight_norm(self.conv, name=name, dim=dim)
  498. return self
  499. def remove_weight_norm(self):
  500. self.conv = remove_parametrizations(self.conv)
  501. return self
  502. def CausalWNConv1d(*args, **kwargs):
  503. return CausalConvNet(*args, **kwargs).weight_norm()
  504. def CausalWNConvTranspose1d(*args, **kwargs):
  505. return CausalTransConvNet(*args, **kwargs).weight_norm()
class ResidualUnit(nn.Module):
    """Snake -> dilated kernel-7 conv -> Snake -> 1x1 conv, with residual add."""

    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # "Same"-style padding for the dilated kernel-7 conv.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            conv_class(dim, dim, kernel_size=1),
        )
        self.causal = causal

    def forward(self, x):
        y = self.block(x)
        # Trim x so the residual add matches y when the convs shortened it.
        pad = x.shape[-1] - y.shape[-1]
        if pad > 0:
            if self.causal:
                # Causal convs only lose samples at the end.
                x = x[..., :-pad]
            else:
                x = x[..., pad // 2 : -pad // 2]
        return x + y
class EncoderBlock(nn.Module):
    """Three residual units, a strided downsampling conv, then an optional transformer."""

    def __init__(
        self,
        dim: int = 16,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # transformer_general_config is called with keyword overrides, so it
        # is expected to be a config factory (e.g. a partial over ModelArgs).
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=dim,
                    window_size=512,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=dim // 64,
                        dim=dim,
                        intermediate_size=dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # Residual units run at the input width (half the output width).
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv_class(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            transformer_module,
        )

    def forward(self, x):
        return self.block(x)
class Encoder(nn.Module):
    """Strided conv encoder: mono waveform in, d_latent channels out."""

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        n_transformer_layers: list = [0, 0, 4, 4],
        transformer_general_config: ModelArgs = None,
        causal: bool = False,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Create first convolution
        self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride, n_t_layer in zip(strides, n_transformer_layers):
            d_model *= 2
            self.block += [
                EncoderBlock(
                    d_model,
                    stride=stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]
        # Create last convolution
        self.block += [
            Snake1d(d_model),
            conv_class(d_model, d_latent, kernel_size=3, padding=1),
        ]
        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
class DecoderBlock(nn.Module):
    """Upsampling transposed conv followed by three residual units."""

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
        n_t_layer: int = 0,
        transformer_general_config=None,
    ):
        super().__init__()
        conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
        # NOTE(review): this transformer is built but never attached to the
        # module (its Sequential entry below is commented out), so it is
        # discarded after __init__ — confirm whether that is intentional.
        transformer_module = (
            nn.Identity()
            if n_t_layer == 0
            else (
                WindowLimitedTransformer(
                    causal=causal,
                    input_dim=input_dim,
                    window_size=None,
                    config=transformer_general_config(
                        n_layer=n_t_layer,
                        n_head=input_dim // 64,
                        dim=input_dim,
                        intermediate_size=input_dim * 3,
                    ),
                )
            )
        )
        self.block = nn.Sequential(
            # transformer_module,
            Snake1d(input_dim),
            conv_trans_class(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)
class Decoder(nn.Module):
    """Conv decoder: latent channels in, d_out waveform channels out (tanh-bounded)."""

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        n_transformer_layers: list = [0, 0, 0, 0],
        transformer_general_config=None,
    ):
        super().__init__()
        conv_class = CausalWNConv1d if causal else WNConv1d
        # Add first conv layer
        layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
        # Add upsampling + MRF blocks; channel count halves at each stage.
        for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [
                DecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    causal=causal,
                    n_t_layer=n_t_layer,
                    transformer_general_config=transformer_general_config,
                )
            ]
        # Add final conv layer.
        # NOTE(review): relies on `output_dim` from the last loop iteration,
        # so `rates` must be non-empty.
        layers += [
            Snake1d(output_dim),
            conv_class(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
  690. class DAC(BaseModel, CodecMixin):
  691. def __init__(
  692. self,
  693. encoder_dim: int = 64,
  694. encoder_rates: List[int] = [2, 4, 8, 8],
  695. latent_dim: int = None,
  696. decoder_dim: int = 1536,
  697. decoder_rates: List[int] = [8, 8, 4, 2],
  698. quantizer: torch.nn.Module = None,
  699. sample_rate: int = 44100,
  700. causal: bool = True,
  701. encoder_transformer_layers: List[int] = [0, 0, 0, 0],
  702. decoder_transformer_layers: List[int] = [0, 0, 0, 0],
  703. transformer_general_config=None,
  704. ):
  705. super().__init__()
  706. self.encoder_dim = encoder_dim
  707. self.encoder_rates = encoder_rates
  708. self.decoder_dim = decoder_dim
  709. self.decoder_rates = decoder_rates
  710. self.sample_rate = sample_rate
  711. if latent_dim is None:
  712. latent_dim = encoder_dim * (2 ** len(encoder_rates))
  713. self.latent_dim = latent_dim
  714. self.hop_length = np.prod(encoder_rates)
  715. self.encoder = Encoder(
  716. encoder_dim,
  717. encoder_rates,
  718. latent_dim,
  719. causal=causal,
  720. n_transformer_layers=encoder_transformer_layers,
  721. transformer_general_config=transformer_general_config,
  722. )
  723. self.quantizer = quantizer
  724. self.decoder = Decoder(
  725. latent_dim,
  726. decoder_dim,
  727. decoder_rates,
  728. causal=causal,
  729. n_transformer_layers=decoder_transformer_layers,
  730. transformer_general_config=transformer_general_config,
  731. )
  732. self.sample_rate = sample_rate
  733. self.apply(init_weights)
  734. self.delay = self.get_delay()
  735. self.frame_length = self.hop_length * 4
  736. def preprocess(self, audio_data, sample_rate):
  737. if sample_rate is None:
  738. sample_rate = self.sample_rate
  739. assert sample_rate == self.sample_rate
  740. length = audio_data.shape[-1]
  741. right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
  742. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  743. return audio_data
  744. def encode(
  745. self,
  746. audio_data: torch.Tensor,
  747. audio_lengths: torch.Tensor = None,
  748. n_quantizers: int = None,
  749. **kwargs,
  750. ):
  751. """Encode given audio data and return quantized latent codes
  752. Parameters
  753. ----------
  754. audio_data : Tensor[B x T]
  755. Audio data to encode
  756. n_quantizers : int, optional
  757. Number of quantizers to use, by default None
  758. If None, all quantizers are used.
  759. Returns
  760. -------
  761. dict
  762. A dictionary with the following keys:
  763. "z" : Tensor[B x D x T]
  764. Quantized continuous representation of input
  765. "codes" : Tensor[B x N x T]
  766. Codebook indices for each codebook
  767. (quantized discrete representation of input)
  768. "latents" : Tensor[B x N*D x T]
  769. Projected latents (continuous representation of input before quantization)
  770. "vq/commitment_loss" : Tensor[1]
  771. Commitment loss to train encoder to predict vectors closer to codebook
  772. entries
  773. "vq/codebook_loss" : Tensor[1]
  774. Codebook loss to update the codebook
  775. "length" : int
  776. Number of samples in input audio
  777. """
  778. # pad to multiple of self.frame_length
  779. if audio_data.ndim == 2:
  780. audio_data = audio_data.unsqueeze(1)
  781. # print(audio_data.shape)
  782. length = audio_data.shape[-1]
  783. right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
  784. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  785. if audio_lengths is None:
  786. audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
  787. z = self.encoder(audio_data)
  788. vq_results = self.quantizer(z, n_quantizers, **kwargs)
  789. indices = vq_results.codes
  790. indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
  791. return indices, indices_lens
  792. def decode(self, indices: torch.Tensor, feature_lengths):
  793. if indices.ndim == 2:
  794. indices = indices[None]
  795. z = self.quantizer.decode(indices)
  796. audio_lengths = feature_lengths * self.frame_length
  797. return self.decoder(z), audio_lengths
  798. def forward(
  799. self,
  800. audio_data: torch.Tensor,
  801. template: torch.Tensor = None,
  802. mask: torch.Tensor = None,
  803. sample_rate: int = None,
  804. n_quantizers: int = None,
  805. **kwargs,
  806. ):
  807. """Model forward pass
  808. Parameters
  809. ----------
  810. audio_data : Tensor[B x 1 x T]
  811. Audio data to encode
  812. sample_rate : int, optional
  813. Sample rate of audio data in Hz, by default None
  814. If None, defaults to `self.sample_rate`
  815. n_quantizers : int, optional
  816. Number of quantizers to use, by default None.
  817. If None, all quantizers are used.
  818. Returns
  819. -------
  820. dict
  821. A dictionary with the following keys:
  822. "z" : Tensor[B x D x T]
  823. Quantized continuous representation of input
  824. "codes" : Tensor[B x N x T]
  825. Codebook indices for each codebook
  826. (quantized discrete representation of input)
  827. "latents" : Tensor[B x N*D x T]
  828. Projected latents (continuous representation of input before quantization)
  829. "vq/commitment_loss" : Tensor[1]
  830. Commitment loss to train encoder to predict vectors closer to codebook
  831. entries
  832. "vq/codebook_loss" : Tensor[1]
  833. Codebook loss to update the codebook
  834. "length" : int
  835. Number of samples in input audio
  836. "audio" : Tensor[B x 1 x length]
  837. Decoded audio data.
  838. """
  839. length = audio_data.shape[-1]
  840. audio_data = self.preprocess(audio_data, sample_rate)
  841. vq_results = self.encode(audio_data, n_quantizers, **kwargs)
  842. z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
  843. x = self.decode(z)
  844. return x[..., :length], vq_results