# modded_dac.py

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045
  1. import math
  2. import typing as tp
  3. from dataclasses import dataclass
  4. from typing import List, Optional, Union
  5. import numpy as np
  6. import torch
  7. from audiotools import AudioSignal
  8. from audiotools.ml import BaseModel
  9. from dac.model.base import CodecMixin
  10. from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
  11. from torch import Tensor, nn
  12. from torch.nn import functional as F
  13. from torch.nn.utils.parametrizations import weight_norm
  14. from torch.nn.utils.parametrize import remove_parametrizations
@dataclass
class VQResult:
    """Container for the outputs of a vector-quantizer forward pass."""

    # Quantized continuous representation.
    z: torch.Tensor
    # Discrete codebook indices.
    codes: torch.Tensor
    # Projected latents before quantization.
    latents: torch.Tensor
    # Loss term that updates the codebook entries.
    codebook_loss: torch.Tensor
    # Loss term pulling encoder outputs toward their codebook entries.
    commitment_loss: torch.Tensor
    # Optional latent used for semantic distillation; None when disabled.
    semantic_distill_z: torch.Tensor | None = None
  23. def find_multiple(n: int, k: int) -> int:
  24. if n % k == 0:
  25. return n
  26. return n + k - (n % k)
@dataclass
class ModelArgs:
    """Hyperparameters for the Transformer stack."""

    block_size: int = 2048  # nominal maximum sequence length — TODO confirm usage
    n_layer: int = 8  # number of TransformerBlocks
    n_head: int = 8  # number of query heads
    dim: int = 512  # model (embedding) dimension
    intermediate_size: int = 1536  # FFN hidden size; derived in __post_init__ if None
    n_local_heads: int = -1  # KV heads (grouped-query attention); -1 means n_head
    head_dim: int = 64  # per-head dimension
    rope_base: float = 10000  # base for rotary-embedding frequencies
    norm_eps: float = 1e-5  # RMSNorm epsilon
    dropout_rate: float = 0.1  # FFN dropout probability
    attn_dropout_rate: float = 0.1  # attention dropout probability
    channels_first: bool = True  # to be compatible with conv1d input/output
    pos_embed_type: str = "rope"  # can be "rope" or "conformer"
    max_relative_position: int = 128  # for conformer-style relative position embedding
    window_size: int = 512  # for window limited attention

    def __post_init__(self):
        # Default grouped-KV heads to full multi-head attention.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        # When unset, size the FFN at 2/3 of 4*dim, rounded up to a multiple of 256.
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        assert self.pos_embed_type in [
            "rope",
            "conformer",
        ], "pos_embed_type must be either 'rope' or 'conformer'"
  55. class KVCache(nn.Module):
  56. def __init__(
  57. self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
  58. ):
  59. super().__init__()
  60. cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
  61. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  62. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  63. def update(self, input_pos, k_val, v_val):
  64. # input_pos: [S], k_val: [B, H, S, D]
  65. assert input_pos.shape[0] == k_val.shape[2]
  66. k_out = self.k_cache
  67. v_out = self.v_cache
  68. k_out[:, :, input_pos] = k_val
  69. v_out[:, :, input_pos] = v_val
  70. return (
  71. k_out[:, :, : input_pos.max() + 1, :],
  72. v_out[:, :, : input_pos.max() + 1, :],
  73. )
  74. def clear_cache(self, prompt_len):
  75. self.k_cache[:, :, prompt_len:, :] = torch.zeros_like(
  76. self.k_cache[:, :, prompt_len:, :]
  77. )
  78. self.v_cache[:, :, prompt_len:, :] = torch.zeros_like(
  79. self.v_cache[:, :, prompt_len:, :]
  80. )
class Transformer(nn.Module):
    """Stack of pre-norm transformer blocks with optional RoPE positional
    embedding, a precomputed causal mask, and optional KV caching."""

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Only compute RoPE frequencies if using RoPE.
        if config.pos_embed_type == "rope":
            freqs_cis = precompute_freqs_cis(
                327680, self.config.head_dim, self.config.rope_base
            )
            self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        else:
            self.register_buffer("freqs_cis", None)
        # Boolean lower-triangular mask, sliced per forward call.
        causal_mask = torch.tril(torch.ones(32768, 32768, dtype=torch.bool))
        self.register_buffer("causal_mask", causal_mask, persistent=False)
        self.max_batch_size = -1  # set by setup_caches()
        self.max_seq_length = -1  # set by setup_caches()
        self.use_kv_cache = False

    def setup_caches(self, max_batch_size, max_seq_length):
        """
        This method will only be called during inference when using KV cache.
        """
        # NOTE(review): derived head_dim (dim // n_head) may differ from
        # config.head_dim used elsewhere — confirm this is intended.
        head_dim = self.config.dim // self.config.n_head
        # Round the cache length up to a multiple of 8.
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.norm.weight.dtype
        device = self.norm.weight.device
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_local_heads,
                head_dim,
                dtype,
            ).to(device)
        self.use_kv_cache = True

    def forward(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the block stack.

        ``input_pos`` indexes both the RoPE table and the causal mask; an
        explicit ``mask`` overrides the default causal masking."""
        if self.config.pos_embed_type == "rope":
            assert (
                self.freqs_cis is not None
            ), "RoPE frequencies must be initialized for RoPE positional embedding"
            freqs_cis = self.freqs_cis[input_pos]
        else:
            freqs_cis = None
        if mask is None:  # in case of non-causal model
            if not self.training and self.use_kv_cache:
                # Decoding with a cache: keys span all cached positions so far.
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., : input_pos.max() + 1]
            else:
                # Full-sequence (training / no-cache) path: square mask.
                mask = self.causal_mask[None, None, input_pos]
                mask = mask[..., input_pos]
        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        return x
  146. class TransformerBlock(nn.Module):
  147. def __init__(self, config: ModelArgs) -> None:
  148. super().__init__()
  149. self.attention = Attention(config)
  150. self.feed_forward = FeedForward(config)
  151. self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
  152. self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
  153. self.attention_layer_scale = LayerScale(config.dim, inplace=True)
  154. self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
  155. def forward(
  156. self,
  157. x: Tensor,
  158. input_pos: Tensor,
  159. freqs_cis: Tensor,
  160. mask: Tensor,
  161. ) -> Tensor:
  162. h = x + self.attention_layer_scale(
  163. self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  164. )
  165. out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
  166. return out
  167. class Attention(nn.Module):
  168. def __init__(self, config: ModelArgs):
  169. super().__init__()
  170. assert config.dim % config.n_head == 0
  171. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  172. # key, query, value projections for all heads, but in a batch
  173. self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
  174. self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
  175. self.kv_cache = None
  176. self.n_head = config.n_head
  177. self.head_dim = config.head_dim
  178. self.n_local_heads = config.n_local_heads
  179. self.dim = config.dim
  180. self.attn_dropout_rate = config.attn_dropout_rate
  181. self.pos_embed_type = config.pos_embed_type
  182. # Add relative position embedding for conformer-style
  183. if self.pos_embed_type == "conformer":
  184. self.max_relative_position = config.max_relative_position
  185. num_pos_embeddings = 2 * config.max_relative_position + 1
  186. self.rel_pos_embeddings = nn.Parameter(
  187. torch.zeros(num_pos_embeddings, self.head_dim)
  188. )
  189. nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
  190. def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
  191. # q: [B, H, S, D]
  192. # Returns: [B, H, S, S]
  193. positions = torch.arange(seqlen, device=q.device)
  194. relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
  195. relative_positions = torch.clamp(
  196. relative_positions + self.max_relative_position,
  197. 0,
  198. 2 * self.max_relative_position,
  199. )
  200. rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
  201. # Compute attention scores with relative position embeddings
  202. q = q.transpose(1, 2) # [B, S, H, D]
  203. rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
  204. rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
  205. return rel_logits
  206. def forward(
  207. self,
  208. x: Tensor,
  209. freqs_cis: Tensor,
  210. mask: Tensor,
  211. input_pos: Optional[Tensor] = None,
  212. ) -> Tensor:
  213. bsz, seqlen, _ = x.shape
  214. kv_size = self.n_local_heads * self.head_dim
  215. q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
  216. context_seqlen = seqlen
  217. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  218. k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  219. v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  220. if self.pos_embed_type == "rope":
  221. q = apply_rotary_emb(q, freqs_cis)
  222. k = apply_rotary_emb(k, freqs_cis)
  223. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  224. if self.kv_cache is not None:
  225. k, v = self.kv_cache.update(input_pos, k, v)
  226. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  227. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  228. if self.pos_embed_type == "conformer":
  229. # Compute attention scores
  230. scale = 1.0 / math.sqrt(self.head_dim)
  231. scores = torch.matmul(q, k.transpose(-2, -1)) * scale
  232. # Add relative position embeddings for conformer-style
  233. rel_scores = self._compute_conformer_pos_scores(q, seqlen)
  234. scores = scores + rel_scores
  235. # Apply attention
  236. if mask is not None:
  237. scores = scores.masked_fill(~mask, float("-inf"))
  238. attn = F.softmax(scores, dim=-1)
  239. if self.attn_dropout_rate > 0 and self.training:
  240. attn = F.dropout(attn, p=self.attn_dropout_rate)
  241. y = torch.matmul(attn, v)
  242. else:
  243. y = F.scaled_dot_product_attention(
  244. q,
  245. k,
  246. v,
  247. dropout_p=self.attn_dropout_rate if self.training else 0.0,
  248. attn_mask=mask,
  249. )
  250. # is_causal=True)
  251. y = (
  252. y.transpose(1, 2)
  253. .contiguous()
  254. .view(bsz, seqlen, self.head_dim * self.n_head)
  255. )
  256. y = self.wo(y)
  257. return y
  258. class FeedForward(nn.Module):
  259. def __init__(self, config: ModelArgs) -> None:
  260. super().__init__()
  261. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  262. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  263. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  264. self.dropout = nn.Dropout(config.dropout_rate)
  265. def forward(self, x: Tensor) -> Tensor:
  266. return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
  267. class RMSNorm(nn.Module):
  268. def __init__(self, dim: int, eps: float = 1e-5):
  269. super().__init__()
  270. self.eps = eps
  271. self.weight = nn.Parameter(torch.ones(dim))
  272. def _norm(self, x):
  273. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  274. def forward(self, x: Tensor) -> Tensor:
  275. output = self._norm(x.float()).type_as(x)
  276. return output * self.weight
  277. class LayerScale(nn.Module):
  278. def __init__(
  279. self,
  280. dim: int,
  281. init_values: Union[float, Tensor] = 1e-2,
  282. inplace: bool = False,
  283. ) -> None:
  284. super().__init__()
  285. self.inplace = inplace
  286. self.gamma = nn.Parameter(init_values * torch.ones(dim))
  287. def forward(self, x: Tensor) -> Tensor:
  288. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  289. class WindowLimitedTransformer(Transformer):
  290. """
  291. Transformer with window limited attention, causal.
  292. """
  293. def __init__(
  294. self,
  295. config: ModelArgs,
  296. input_dim: int = 512,
  297. window_size: Optional[int] = None,
  298. causal: bool = True,
  299. look_ahead_conv: nn.Module = None,
  300. ):
  301. super().__init__(config)
  302. self.window_size = window_size
  303. self.causal = causal
  304. self.channels_first = config.channels_first
  305. self.look_ahead_conv = (
  306. look_ahead_conv if look_ahead_conv is not None else nn.Identity()
  307. )
  308. self.input_proj = (
  309. nn.Linear(input_dim, config.dim)
  310. if input_dim != config.dim
  311. else nn.Identity()
  312. )
  313. self.output_proj = (
  314. nn.Linear(config.dim, input_dim)
  315. if input_dim != config.dim
  316. else nn.Identity()
  317. )
  318. def make_window_limited_mask(
  319. self,
  320. max_length: int,
  321. x_lens: Optional[Tensor] = None,
  322. ) -> Tensor:
  323. """
  324. Make mask to form window limited attention.
  325. """
  326. if self.causal:
  327. mask = torch.tril(torch.ones(max_length, max_length))
  328. row_indices = torch.arange(max_length).view(-1, 1)
  329. window_size = self.window_size or max_length
  330. valid_range = (row_indices - window_size + 1).clamp(min=0)
  331. column_indices = torch.arange(max_length)
  332. mask = (column_indices >= valid_range) & mask.bool()
  333. else:
  334. raise NotImplementedError
  335. mask = mask.bool()[None, None]
  336. return mask
  337. def make_mask(
  338. self,
  339. max_length: int,
  340. x_lens: Optional[Tensor] = None,
  341. ) -> Tensor:
  342. """
  343. Make ordinary mask if window size is not specified.
  344. """
  345. if self.causal:
  346. mask = torch.tril(torch.ones(max_length, max_length))
  347. else:
  348. mask = torch.ones(max_length, max_length)
  349. mask = mask.bool()[None, None]
  350. for i, x_len in enumerate(x_lens):
  351. mask[:x_len, i] = 0
  352. mask = mask.bool()[None, None]
  353. return mask
  354. def forward(
  355. self,
  356. x: Tensor,
  357. x_lens: Optional[Tensor] = None,
  358. ) -> Tensor:
  359. if self.channels_first:
  360. x = x.transpose(1, 2)
  361. x = self.input_proj(x) # (B, T, D)
  362. x = self.look_ahead_conv(x)
  363. input_pos = torch.arange(x.shape[1], device=x.device)
  364. # construct mask to form window limited attention
  365. max_length = x.shape[1]
  366. if self.window_size is not None:
  367. mask = self.make_window_limited_mask(max_length, x_lens)
  368. else:
  369. mask = self.make_mask(max_length, x_lens)
  370. mask = mask.to(x.device)
  371. x = super().forward(x, input_pos, mask)
  372. x = self.output_proj(x) # (B, T, D)
  373. if self.channels_first:
  374. x = x.transpose(1, 2)
  375. return x
  376. def precompute_freqs_cis(
  377. seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
  378. ) -> Tensor:
  379. freqs = 1.0 / (
  380. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  381. )
  382. t = torch.arange(seq_len, device=freqs.device)
  383. freqs = torch.outer(t, freqs)
  384. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  385. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  386. return cache.to(dtype=dtype)
  387. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  388. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  389. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  390. x_out2 = torch.stack(
  391. [
  392. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  393. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  394. ],
  395. -1,
  396. )
  397. x_out2 = x_out2.flatten(3)
  398. return x_out2.type_as(x)
  399. def init_weights(m):
  400. if isinstance(m, nn.Conv1d):
  401. nn.init.trunc_normal_(m.weight, std=0.02)
  402. nn.init.constant_(m.bias, 0)
  403. def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  404. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  405. padding_left, padding_right = paddings
  406. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  407. assert (padding_left + padding_right) <= x.shape[-1]
  408. end = x.shape[-1] - padding_right
  409. return x[..., padding_left:end]
  410. def get_extra_padding_for_conv1d(
  411. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  412. ) -> int:
  413. """See `pad_for_conv1d`."""
  414. length = x.shape[-1]
  415. n_frames = (length - kernel_size + padding_total) / stride + 1
  416. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  417. return ideal_length - length
  418. def pad1d(
  419. x: torch.Tensor,
  420. paddings: tp.Tuple[int, int],
  421. mode: str = "zeros",
  422. value: float = 0.0,
  423. ):
  424. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  425. If this is the case, we insert extra 0 padding to the right
  426. before the reflection happen.
  427. """
  428. length = x.shape[-1]
  429. padding_left, padding_right = paddings
  430. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  431. if mode == "reflect":
  432. max_pad = max(padding_left, padding_right)
  433. extra_pad = 0
  434. if length <= max_pad:
  435. extra_pad = max_pad - length + 1
  436. x = F.pad(x, (0, extra_pad))
  437. padded = F.pad(x, paddings, mode, value)
  438. end = padded.shape[-1] - extra_pad
  439. return padded[..., :end]
  440. else:
  441. return F.pad(x, paddings, mode, value)
  442. class CausalConvNet(nn.Module):
  443. def __init__(
  444. self,
  445. in_channels,
  446. out_channels,
  447. kernel_size,
  448. dilation=1,
  449. stride=1,
  450. groups=1,
  451. padding=None,
  452. ):
  453. super(CausalConvNet, self).__init__()
  454. self.conv = nn.Conv1d(
  455. in_channels,
  456. out_channels,
  457. kernel_size,
  458. stride=stride,
  459. dilation=dilation,
  460. groups=groups,
  461. )
  462. self.stride = stride
  463. self.kernel_size = (kernel_size - 1) * dilation + 1
  464. self.dilation = dilation
  465. self.padding = self.kernel_size - self.stride
  466. def forward(self, x):
  467. pad = self.padding
  468. extra_padding = get_extra_padding_for_conv1d(
  469. x, self.kernel_size, self.stride, pad
  470. )
  471. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  472. return self.conv(x).contiguous()
  473. def weight_norm(self, name="weight", dim=0):
  474. self.conv = weight_norm(self.conv, name=name, dim=dim)
  475. return self
  476. def remove_weight_norm(self):
  477. self.conv = remove_parametrizations(self.conv)
  478. return self
  479. class CausalTransConvNet(nn.Module):
  480. def __init__(
  481. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
  482. ):
  483. super(CausalTransConvNet, self).__init__()
  484. self.conv = nn.ConvTranspose1d(
  485. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  486. )
  487. self.stride = stride
  488. self.kernel_size = kernel_size
  489. def forward(self, x):
  490. x = self.conv(x)
  491. pad = self.kernel_size - self.stride
  492. padding_right = math.ceil(pad)
  493. padding_left = pad - padding_right
  494. x = unpad1d(x, (padding_left, padding_right))
  495. return x.contiguous()
  496. def weight_norm(self, name="weight", dim=0):
  497. self.conv = weight_norm(self.conv, name=name, dim=dim)
  498. return self
  499. def remove_weight_norm(self):
  500. self.conv = remove_parametrizations(self.conv)
  501. return self
  502. def CausalWNConv1d(*args, **kwargs):
  503. return CausalConvNet(*args, **kwargs).weight_norm()
  504. def CausalWNConvTranspose1d(*args, **kwargs):
  505. return CausalTransConvNet(*args, **kwargs).weight_norm()
  506. class ResidualUnit(nn.Module):
  507. def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
  508. super().__init__()
  509. conv_class = CausalWNConv1d if causal else WNConv1d
  510. pad = ((7 - 1) * dilation) // 2
  511. self.block = nn.Sequential(
  512. Snake1d(dim),
  513. conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
  514. Snake1d(dim),
  515. conv_class(dim, dim, kernel_size=1),
  516. )
  517. self.causal = causal
  518. def forward(self, x):
  519. y = self.block(x)
  520. pad = x.shape[-1] - y.shape[-1]
  521. if pad > 0:
  522. if self.causal:
  523. x = x[..., :-pad]
  524. else:
  525. x = x[..., pad // 2 : -pad // 2]
  526. return x + y
  527. class EncoderBlock(nn.Module):
  528. def __init__(
  529. self,
  530. dim: int = 16,
  531. stride: int = 1,
  532. causal: bool = False,
  533. n_t_layer: int = 0,
  534. transformer_general_config=None,
  535. ):
  536. super().__init__()
  537. conv_class = CausalWNConv1d if causal else WNConv1d
  538. transformer_module = (
  539. nn.Identity()
  540. if n_t_layer == 0
  541. else (
  542. WindowLimitedTransformer(
  543. causal=causal,
  544. input_dim=dim,
  545. window_size=getattr(transformer_general_config, "window_size", 512),
  546. config=transformer_general_config(
  547. n_layer=n_t_layer,
  548. n_head=dim // 64,
  549. dim=dim,
  550. intermediate_size=dim * 3,
  551. ),
  552. )
  553. )
  554. )
  555. self.block = nn.Sequential(
  556. ResidualUnit(dim // 2, dilation=1, causal=causal),
  557. ResidualUnit(dim // 2, dilation=3, causal=causal),
  558. ResidualUnit(dim // 2, dilation=9, causal=causal),
  559. Snake1d(dim // 2),
  560. conv_class(
  561. dim // 2,
  562. dim,
  563. kernel_size=2 * stride,
  564. stride=stride,
  565. padding=math.ceil(stride / 2),
  566. ),
  567. transformer_module,
  568. )
  569. def forward(self, x):
  570. return self.block(x)
  571. class Encoder(nn.Module):
  572. def __init__(
  573. self,
  574. d_model: int = 64,
  575. strides: list = [2, 4, 8, 8],
  576. d_latent: int = 64,
  577. n_transformer_layers: list = [0, 0, 4, 4],
  578. transformer_general_config: ModelArgs = None,
  579. causal: bool = False,
  580. ):
  581. super().__init__()
  582. conv_class = CausalWNConv1d if causal else WNConv1d
  583. # Create first convolution
  584. self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
  585. # Create EncoderBlocks that double channels as they downsample by `stride`
  586. for stride, n_t_layer in zip(strides, n_transformer_layers):
  587. d_model *= 2
  588. self.block += [
  589. EncoderBlock(
  590. d_model,
  591. stride=stride,
  592. causal=causal,
  593. n_t_layer=n_t_layer,
  594. transformer_general_config=transformer_general_config,
  595. )
  596. ]
  597. # Create last convolution
  598. self.block += [
  599. Snake1d(d_model),
  600. conv_class(d_model, d_latent, kernel_size=3, padding=1),
  601. ]
  602. # Wrap black into nn.Sequential
  603. self.block = nn.Sequential(*self.block)
  604. self.enc_dim = d_model
  605. def forward(self, x):
  606. return self.block(x)
  607. class DecoderBlock(nn.Module):
  608. def __init__(
  609. self,
  610. input_dim: int = 16,
  611. output_dim: int = 8,
  612. stride: int = 1,
  613. causal: bool = False,
  614. n_t_layer: int = 0,
  615. transformer_general_config=None,
  616. ):
  617. super().__init__()
  618. conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
  619. transformer_module = (
  620. nn.Identity()
  621. if n_t_layer == 0
  622. else (
  623. WindowLimitedTransformer(
  624. causal=causal,
  625. input_dim=input_dim,
  626. window_size=None,
  627. config=transformer_general_config(
  628. n_layer=n_t_layer,
  629. n_head=input_dim // 64,
  630. dim=input_dim,
  631. intermediate_size=input_dim * 3,
  632. ),
  633. )
  634. )
  635. )
  636. self.block = nn.Sequential(
  637. # transformer_module,
  638. Snake1d(input_dim),
  639. conv_trans_class(
  640. input_dim,
  641. output_dim,
  642. kernel_size=2 * stride,
  643. stride=stride,
  644. padding=math.ceil(stride / 2),
  645. ),
  646. ResidualUnit(output_dim, dilation=1, causal=causal),
  647. ResidualUnit(output_dim, dilation=3, causal=causal),
  648. ResidualUnit(output_dim, dilation=9, causal=causal),
  649. )
  650. def forward(self, x):
  651. return self.block(x)
  652. class Decoder(nn.Module):
  653. def __init__(
  654. self,
  655. input_channel,
  656. channels,
  657. rates,
  658. d_out: int = 1,
  659. causal: bool = False,
  660. n_transformer_layers: list = [0, 0, 0, 0],
  661. transformer_general_config=None,
  662. ):
  663. super().__init__()
  664. conv_class = CausalWNConv1d if causal else WNConv1d
  665. # Add first conv layer
  666. layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
  667. # Add upsampling + MRF blocks
  668. for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
  669. input_dim = channels // 2**i
  670. output_dim = channels // 2 ** (i + 1)
  671. layers += [
  672. DecoderBlock(
  673. input_dim,
  674. output_dim,
  675. stride,
  676. causal=causal,
  677. n_t_layer=n_t_layer,
  678. transformer_general_config=transformer_general_config,
  679. )
  680. ]
  681. # Add final conv layer
  682. layers += [
  683. Snake1d(output_dim),
  684. conv_class(output_dim, d_out, kernel_size=7, padding=3),
  685. nn.Tanh(),
  686. ]
  687. self.model = nn.Sequential(*layers)
  688. def forward(self, x):
  689. return self.model(x)
  690. class DAC(BaseModel, CodecMixin):
  691. def __init__(
  692. self,
  693. encoder_dim: int = 64,
  694. encoder_rates: List[int] = [2, 4, 8, 8],
  695. latent_dim: int = None,
  696. decoder_dim: int = 1536,
  697. decoder_rates: List[int] = [8, 8, 4, 2],
  698. quantizer: torch.nn.Module = None,
  699. sample_rate: int = 44100,
  700. causal: bool = True,
  701. encoder_transformer_layers: List[int] = [0, 0, 0, 0],
  702. decoder_transformer_layers: List[int] = [0, 0, 0, 0],
  703. overwrite_decoder: torch.nn.Module = None,
  704. transformer_general_config=None,
  705. ):
  706. super().__init__()
  707. self.encoder_dim = encoder_dim
  708. self.encoder_rates = encoder_rates
  709. self.decoder_dim = decoder_dim
  710. self.decoder_rates = decoder_rates
  711. self.sample_rate = sample_rate
  712. if latent_dim is None:
  713. latent_dim = encoder_dim * (2 ** len(encoder_rates))
  714. self.latent_dim = latent_dim
  715. self.hop_length = np.prod(encoder_rates)
  716. self.encoder = Encoder(
  717. encoder_dim,
  718. encoder_rates,
  719. latent_dim,
  720. causal=causal,
  721. n_transformer_layers=encoder_transformer_layers,
  722. transformer_general_config=transformer_general_config,
  723. )
  724. self.quantizer = quantizer
  725. if overwrite_decoder is not None:
  726. self.decoder = overwrite_decoder
  727. else:
  728. self.decoder = Decoder(
  729. latent_dim,
  730. decoder_dim,
  731. decoder_rates,
  732. causal=causal,
  733. n_transformer_layers=decoder_transformer_layers,
  734. transformer_general_config=transformer_general_config,
  735. )
  736. self.sample_rate = sample_rate
  737. self.apply(init_weights)
  738. self.delay = self.get_delay()
  739. self.frame_length = self.hop_length * 4
  740. def preprocess(self, audio_data, sample_rate):
  741. if sample_rate is None:
  742. sample_rate = self.sample_rate
  743. assert sample_rate == self.sample_rate
  744. length = audio_data.shape[-1]
  745. right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
  746. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  747. return audio_data
  748. def encode(
  749. self,
  750. audio_data: torch.Tensor,
  751. audio_lengths: torch.Tensor = None,
  752. n_quantizers: int = None,
  753. **kwargs,
  754. ):
  755. """Encode given audio data and return quantized latent codes
  756. Parameters
  757. ----------
  758. audio_data : Tensor[B x T]
  759. Audio data to encode
  760. n_quantizers : int, optional
  761. Number of quantizers to use, by default None
  762. If None, all quantizers are used.
  763. Returns
  764. -------
  765. dict
  766. A dictionary with the following keys:
  767. "z" : Tensor[B x D x T]
  768. Quantized continuous representation of input
  769. "codes" : Tensor[B x N x T]
  770. Codebook indices for each codebook
  771. (quantized discrete representation of input)
  772. "latents" : Tensor[B x N*D x T]
  773. Projected latents (continuous representation of input before quantization)
  774. "vq/commitment_loss" : Tensor[1]
  775. Commitment loss to train encoder to predict vectors closer to codebook
  776. entries
  777. "vq/codebook_loss" : Tensor[1]
  778. Codebook loss to update the codebook
  779. "length" : int
  780. Number of samples in input audio
  781. """
  782. # pad to multiple of self.frame_length
  783. if audio_data.ndim == 2:
  784. audio_data = audio_data.unsqueeze(1)
  785. length = audio_data.shape[-1]
  786. right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
  787. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  788. if audio_lengths is None:
  789. audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
  790. z = self.encoder(audio_data)
  791. vq_results = self.quantizer(z, n_quantizers, **kwargs)
  792. indices = vq_results.codes
  793. indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
  794. return indices, indices_lens
  795. def from_indices(self, indices: torch.Tensor):
  796. z = self.quantizer.decode(indices)
  797. return self.decoder(z)
  798. def decode(self, z: torch.Tensor):
  799. """Decode given latent codes and return audio data
  800. Parameters
  801. ----------
  802. z : Tensor[B x D x T]
  803. Quantized continuous representation of input
  804. length : int, optional
  805. Number of samples in output audio, by default None
  806. Returns
  807. -------
  808. dict
  809. A dictionary with the following keys:
  810. "audio" : Tensor[B x 1 x length]
  811. Decoded audio data.
  812. """
  813. return self.decoder(z)
  814. def forward(
  815. self,
  816. audio_data: torch.Tensor,
  817. template: torch.Tensor = None,
  818. mask: torch.Tensor = None,
  819. sample_rate: int = None,
  820. n_quantizers: int = None,
  821. **kwargs,
  822. ):
  823. """Model forward pass
  824. Parameters
  825. ----------
  826. audio_data : Tensor[B x 1 x T]
  827. Audio data to encode
  828. sample_rate : int, optional
  829. Sample rate of audio data in Hz, by default None
  830. If None, defaults to `self.sample_rate`
  831. n_quantizers : int, optional
  832. Number of quantizers to use, by default None.
  833. If None, all quantizers are used.
  834. Returns
  835. -------
  836. dict
  837. A dictionary with the following keys:
  838. "z" : Tensor[B x D x T]
  839. Quantized continuous representation of input
  840. "codes" : Tensor[B x N x T]
  841. Codebook indices for each codebook
  842. (quantized discrete representation of input)
  843. "latents" : Tensor[B x N*D x T]
  844. Projected latents (continuous representation of input before quantization)
  845. "vq/commitment_loss" : Tensor[1]
  846. Commitment loss to train encoder to predict vectors closer to codebook
  847. entries
  848. "vq/codebook_loss" : Tensor[1]
  849. Codebook loss to update the codebook
  850. "length" : int
  851. Number of samples in input audio
  852. "audio" : Tensor[B x 1 x length]
  853. Decoded audio data.
  854. """
  855. length = audio_data.shape[-1]
  856. audio_data = self.preprocess(audio_data, sample_rate)
  857. vq_results = self.encode(audio_data, n_quantizers, **kwargs)
  858. z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
  859. x = self.decode(z)
  860. return x[..., :length], vq_results
  861. if __name__ == "__main__":
  862. import hydra
  863. import numpy as np
  864. import soundfile as sf
  865. import torch
  866. from omegaconf import OmegaConf
  867. # 配置路径
  868. config_path = "fish_speech/configs/modded_dac_vq.yaml"
  869. checkpoint_path = "checkpoints/s2-pro/codec.pth"
  870. codes_path = "./output/codes_0.npy" # 你的 codes 文件路径
  871. output_path = "reconstructed_from_codes.wav"
  872. sample_rate = 44100 # 请确保采样率与模型训练时一致
  873. with torch.inference_mode():
  874. # 1. 初始化模型
  875. model = hydra.utils.instantiate(OmegaConf.load(config_path))
  876. new_sd = torch.load(checkpoint_path, map_location="cpu")
  877. model.load_state_dict(new_sd, strict=False)
  878. model.cuda()
  879. model.eval()
  880. # 2. 加载外部 codes (.npy)
  881. # 预期 shape 通常为 [num_codebooks, seq_len] 或 [1, num_codebooks, seq_len]
  882. codes_np = np.load(codes_path)
  883. codes_tensor = torch.from_numpy(codes_np).to(torch.long).cuda()
  884. # 如果 codes 没有 batch 维度,增加一个维度 [1, num_codebooks, seq_len]
  885. if len(codes_tensor.shape) == 2:
  886. codes_tensor = codes_tensor.unsqueeze(0)
  887. print(f"Loaded codes shape: {codes_tensor.shape}")
  888. # 3. 直接从 codes 重建音频 (Decoding)
  889. # 注意:fish_speech 的 model.from_indices 通常接受的输入是 LongTensor
  890. fake_audio = model.from_indices(codes_tensor)
  891. # 4. 后处理与保存
  892. # fake_audio 形状通常为 [B, C, T]
  893. audio_np = fake_audio.squeeze().cpu().numpy()
  894. # 如果是多声道,转置为 soundfile 要求的 (samples, channels)
  895. if len(audio_np.shape) == 2:
  896. audio_np = audio_np.T
  897. sf.write(output_path, audio_np, sample_rate)
  898. print(f"重建完成。音频已保存至: {output_path}")