# modded_dac.py
  1. import math
  2. import typing as tp
  3. from dataclasses import dataclass
  4. from typing import List, Optional, Union
  5. import numpy as np
  6. import torch
  7. from audiotools import AudioSignal
  8. from audiotools.ml import BaseModel
  9. from dac.model.base import CodecMixin
  10. from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
  11. from torch import Tensor, nn
  12. from torch.nn import functional as F
  13. from torch.nn.utils.parametrizations import weight_norm
  14. from torch.nn.utils.parametrize import remove_parametrizations
  15. @dataclass
  16. class VQResult:
  17. z: torch.Tensor
  18. codes: torch.Tensor
  19. latents: torch.Tensor
  20. codebook_loss: torch.Tensor
  21. commitment_loss: torch.Tensor
  22. semantic_distill_z: torch.Tensor | None = None
  23. def find_multiple(n: int, k: int) -> int:
  24. if n % k == 0:
  25. return n
  26. return n + k - (n % k)
  27. @dataclass
  28. class ModelArgs:
  29. block_size: int = 2048
  30. n_layer: int = 8
  31. n_head: int = 8
  32. dim: int = 512
  33. intermediate_size: int = 1536
  34. n_local_heads: int = -1
  35. head_dim: int = 64
  36. rope_base: float = 10000
  37. norm_eps: float = 1e-5
  38. dropout_rate: float = 0.1
  39. attn_dropout_rate: float = 0.1
  40. channels_first: bool = True # to be compatible with conv1d input/output
  41. pos_embed_type: str = "rope" # can be "rope" or "conformer"
  42. max_relative_position: int = 128 # for conformer-style relative position embedding
  43. window_size: int = 512 # for window limited attention
  44. def __post_init__(self):
  45. if self.n_local_heads == -1:
  46. self.n_local_heads = self.n_head
  47. if self.intermediate_size is None:
  48. hidden_dim = 4 * self.dim
  49. n_hidden = int(2 * hidden_dim / 3)
  50. self.intermediate_size = find_multiple(n_hidden, 256)
  51. assert self.pos_embed_type in [
  52. "rope",
  53. "conformer",
  54. ], "pos_embed_type must be either 'rope' or 'conformer'"
  55. class KVCache(nn.Module):
  56. def __init__(
  57. self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
  58. ):
  59. super().__init__()
  60. cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
  61. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  62. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  63. def update(self, input_pos, k_val, v_val):
  64. # input_pos: [S], k_val: [B, H, S, D]
  65. assert input_pos.shape[0] == k_val.shape[2]
  66. k_out = self.k_cache
  67. v_out = self.v_cache
  68. k_out[:, :, input_pos] = k_val
  69. v_out[:, :, input_pos] = v_val
  70. return (
  71. k_out[:, :, : input_pos.max() + 1, :],
  72. v_out[:, :, : input_pos.max() + 1, :],
  73. )
  74. def clear_cache(self, prompt_len):
  75. self.k_cache[:, :, prompt_len:, :].fill_(0)
  76. self.v_cache[:, :, prompt_len:, :].fill_(0)
  77. class Transformer(nn.Module):
  78. def __init__(self, config: ModelArgs) -> None:
  79. super().__init__()
  80. self.config = config
  81. self.layers = nn.ModuleList(
  82. TransformerBlock(config) for _ in range(config.n_layer)
  83. )
  84. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  85. # Only compute RoPE frequencies if using RoPE
  86. if config.pos_embed_type == "rope":
  87. freqs_cis = precompute_freqs_cis(
  88. 327680, self.config.head_dim, self.config.rope_base
  89. )
  90. self.register_buffer("freqs_cis", freqs_cis, persistent=False)
  91. else:
  92. self.register_buffer("freqs_cis", None)
  93. causal_mask = torch.tril(torch.ones(32768, 32768, dtype=torch.bool))
  94. self.register_buffer("causal_mask", causal_mask, persistent=False)
  95. self.max_batch_size = -1
  96. self.max_seq_length = -1
  97. self.use_kv_cache = False
  98. def setup_caches(self, max_batch_size, max_seq_length):
  99. """
  100. This method will only be called during inference when using KV cache.
  101. """
  102. head_dim = self.config.dim // self.config.n_head
  103. max_seq_length = find_multiple(max_seq_length, 8)
  104. self.max_seq_length = max_seq_length
  105. self.max_batch_size = max_batch_size
  106. dtype = self.norm.weight.dtype
  107. device = self.norm.weight.device
  108. for b in self.layers:
  109. b.attention.kv_cache = KVCache(
  110. max_batch_size,
  111. max_seq_length,
  112. self.config.n_local_heads,
  113. head_dim,
  114. dtype,
  115. ).to(device)
  116. self.use_kv_cache = True
  117. def forward(
  118. self,
  119. x: Tensor,
  120. input_pos: Optional[Tensor] = None,
  121. mask: Optional[Tensor] = None,
  122. ) -> Tensor:
  123. if self.config.pos_embed_type == "rope":
  124. assert (
  125. self.freqs_cis is not None
  126. ), "RoPE frequencies must be initialized for RoPE positional embedding"
  127. # print("MAX", input_pos.max())
  128. freqs_cis = self.freqs_cis[input_pos]
  129. else:
  130. freqs_cis = None
  131. if mask is None: # in case of non-causal model
  132. if not self.training and self.use_kv_cache:
  133. mask = self.causal_mask[None, None, input_pos]
  134. mask = mask[..., : input_pos.max() + 1]
  135. else:
  136. mask = self.causal_mask[None, None, input_pos]
  137. mask = mask[..., input_pos]
  138. for i, layer in enumerate(self.layers):
  139. x = layer(x, input_pos, freqs_cis, mask)
  140. x = self.norm(x)
  141. return x
  142. class TransformerBlock(nn.Module):
  143. def __init__(self, config: ModelArgs) -> None:
  144. super().__init__()
  145. self.attention = Attention(config)
  146. self.feed_forward = FeedForward(config)
  147. self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
  148. self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
  149. self.attention_layer_scale = LayerScale(config.dim, inplace=True)
  150. self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
  151. def forward(
  152. self,
  153. x: Tensor,
  154. input_pos: Tensor,
  155. freqs_cis: Tensor,
  156. mask: Tensor,
  157. ) -> Tensor:
  158. h = x + self.attention_layer_scale(
  159. self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  160. )
  161. out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
  162. return out
  163. class Attention(nn.Module):
  164. def __init__(self, config: ModelArgs):
  165. super().__init__()
  166. assert config.dim % config.n_head == 0
  167. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  168. # key, query, value projections for all heads, but in a batch
  169. self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
  170. self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
  171. self.kv_cache = None
  172. self.n_head = config.n_head
  173. self.head_dim = config.head_dim
  174. self.n_local_heads = config.n_local_heads
  175. self.dim = config.dim
  176. self.attn_dropout_rate = config.attn_dropout_rate
  177. self.pos_embed_type = config.pos_embed_type
  178. # Add relative position embedding for conformer-style
  179. if self.pos_embed_type == "conformer":
  180. self.max_relative_position = config.max_relative_position
  181. num_pos_embeddings = 2 * config.max_relative_position + 1
  182. self.rel_pos_embeddings = nn.Parameter(
  183. torch.zeros(num_pos_embeddings, self.head_dim)
  184. )
  185. nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
  186. def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
  187. # q: [B, H, S, D]
  188. # Returns: [B, H, S, S]
  189. positions = torch.arange(seqlen, device=q.device)
  190. relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
  191. relative_positions = torch.clamp(
  192. relative_positions + self.max_relative_position,
  193. 0,
  194. 2 * self.max_relative_position,
  195. )
  196. rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
  197. # Compute attention scores with relative position embeddings
  198. q = q.transpose(1, 2) # [B, S, H, D]
  199. rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
  200. rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
  201. return rel_logits
  202. def forward(
  203. self,
  204. x: Tensor,
  205. freqs_cis: Tensor,
  206. mask: Tensor,
  207. input_pos: Optional[Tensor] = None,
  208. ) -> Tensor:
  209. bsz, seqlen, _ = x.shape
  210. kv_size = self.n_local_heads * self.head_dim
  211. q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
  212. context_seqlen = seqlen
  213. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  214. k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  215. v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  216. if self.pos_embed_type == "rope":
  217. q = apply_rotary_emb(q, freqs_cis)
  218. k = apply_rotary_emb(k, freqs_cis)
  219. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  220. if self.kv_cache is not None:
  221. k, v = self.kv_cache.update(input_pos, k, v)
  222. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  223. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  224. if self.pos_embed_type == "conformer":
  225. # Compute attention scores
  226. scale = 1.0 / math.sqrt(self.head_dim)
  227. scores = torch.matmul(q, k.transpose(-2, -1)) * scale
  228. # Add relative position embeddings for conformer-style
  229. rel_scores = self._compute_conformer_pos_scores(q, seqlen)
  230. scores = scores + rel_scores
  231. # Apply attention
  232. if mask is not None:
  233. scores = scores.masked_fill(~mask, float("-inf"))
  234. attn = F.softmax(scores, dim=-1)
  235. if self.attn_dropout_rate > 0 and self.training:
  236. attn = F.dropout(attn, p=self.attn_dropout_rate)
  237. y = torch.matmul(attn, v)
  238. else:
  239. y = F.scaled_dot_product_attention(
  240. q,
  241. k,
  242. v,
  243. dropout_p=self.attn_dropout_rate if self.training else 0.0,
  244. attn_mask=mask,
  245. )
  246. # is_causal=True)
  247. y = (
  248. y.transpose(1, 2)
  249. .contiguous()
  250. .view(bsz, seqlen, self.head_dim * self.n_head)
  251. )
  252. y = self.wo(y)
  253. return y
  254. class FeedForward(nn.Module):
  255. def __init__(self, config: ModelArgs) -> None:
  256. super().__init__()
  257. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  258. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  259. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  260. self.dropout = nn.Dropout(config.dropout_rate)
  261. def forward(self, x: Tensor) -> Tensor:
  262. return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
  263. class RMSNorm(nn.Module):
  264. def __init__(self, dim: int, eps: float = 1e-5):
  265. super().__init__()
  266. self.eps = eps
  267. self.weight = nn.Parameter(torch.ones(dim))
  268. def _norm(self, x):
  269. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  270. def forward(self, x: Tensor) -> Tensor:
  271. output = self._norm(x.float()).type_as(x)
  272. return output * self.weight
  273. class LayerScale(nn.Module):
  274. def __init__(
  275. self,
  276. dim: int,
  277. init_values: Union[float, Tensor] = 1e-2,
  278. inplace: bool = False,
  279. ) -> None:
  280. super().__init__()
  281. self.inplace = inplace
  282. self.gamma = nn.Parameter(init_values * torch.ones(dim))
  283. def forward(self, x: Tensor) -> Tensor:
  284. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  285. class WindowLimitedTransformer(Transformer):
  286. """
  287. Transformer with window limited attention, causal.
  288. """
  289. def __init__(
  290. self,
  291. config: ModelArgs,
  292. input_dim: int = 512,
  293. window_size: Optional[int] = None,
  294. causal: bool = True,
  295. look_ahead_conv: nn.Module = None,
  296. ):
  297. super().__init__(config)
  298. self.window_size = window_size
  299. self.causal = causal
  300. self.channels_first = config.channels_first
  301. self.look_ahead_conv = (
  302. look_ahead_conv if look_ahead_conv is not None else nn.Identity()
  303. )
  304. self.input_proj = (
  305. nn.Linear(input_dim, config.dim)
  306. if input_dim != config.dim
  307. else nn.Identity()
  308. )
  309. self.output_proj = (
  310. nn.Linear(config.dim, input_dim)
  311. if input_dim != config.dim
  312. else nn.Identity()
  313. )
  314. def make_window_limited_mask(
  315. self,
  316. max_length: int,
  317. x_lens: Optional[Tensor] = None,
  318. ) -> Tensor:
  319. """
  320. Make mask to form window limited attention.
  321. """
  322. if self.causal:
  323. mask = torch.tril(torch.ones(max_length, max_length))
  324. row_indices = torch.arange(max_length).view(-1, 1)
  325. window_size = self.window_size or max_length
  326. valid_range = (row_indices - window_size + 1).clamp(min=0)
  327. column_indices = torch.arange(max_length)
  328. mask = (column_indices >= valid_range) & mask.bool()
  329. else:
  330. raise NotImplementedError
  331. mask = mask.bool()[None, None]
  332. return mask
  333. def make_mask(
  334. self,
  335. max_length: int,
  336. x_lens: Optional[Tensor] = None,
  337. ) -> Tensor:
  338. """
  339. Make ordinary mask if window size is not specified.
  340. """
  341. if self.causal:
  342. mask = torch.tril(torch.ones(max_length, max_length))
  343. else:
  344. mask = torch.ones(max_length, max_length)
  345. mask = mask.bool()[None, None]
  346. for i, x_len in enumerate(x_lens):
  347. mask[:x_len, i] = 0
  348. mask = mask.bool()[None, None]
  349. return mask
  350. def forward(
  351. self,
  352. x: Tensor,
  353. x_lens: Optional[Tensor] = None,
  354. ) -> Tensor:
  355. if self.channels_first:
  356. x = x.transpose(1, 2)
  357. x = self.input_proj(x) # (B, T, D)
  358. x = self.look_ahead_conv(x)
  359. input_pos = torch.arange(x.shape[1], device=x.device)
  360. # construct mask to form window limited attention
  361. max_length = x.shape[1]
  362. if self.window_size is not None:
  363. mask = self.make_window_limited_mask(max_length, x_lens)
  364. else:
  365. mask = self.make_mask(max_length, x_lens)
  366. mask = mask.to(x.device)
  367. x = super().forward(x, input_pos, mask)
  368. x = self.output_proj(x) # (B, T, D)
  369. if self.channels_first:
  370. x = x.transpose(1, 2)
  371. return x
  372. def precompute_freqs_cis(
  373. seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
  374. ) -> Tensor:
  375. freqs = 1.0 / (
  376. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  377. )
  378. t = torch.arange(seq_len, device=freqs.device)
  379. freqs = torch.outer(t, freqs)
  380. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  381. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  382. return cache.to(dtype=dtype)
  383. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  384. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  385. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  386. x_out2 = torch.stack(
  387. [
  388. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  389. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  390. ],
  391. -1,
  392. )
  393. x_out2 = x_out2.flatten(3)
  394. return x_out2.type_as(x)
  395. def init_weights(m):
  396. if isinstance(m, nn.Conv1d):
  397. nn.init.trunc_normal_(m.weight, std=0.02)
  398. nn.init.constant_(m.bias, 0)
  399. def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  400. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  401. padding_left, padding_right = paddings
  402. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  403. assert (padding_left + padding_right) <= x.shape[-1]
  404. end = x.shape[-1] - padding_right
  405. return x[..., padding_left:end]
  406. def get_extra_padding_for_conv1d(
  407. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  408. ) -> int:
  409. """See `pad_for_conv1d`."""
  410. length = x.shape[-1]
  411. n_frames = (length - kernel_size + padding_total) / stride + 1
  412. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  413. return ideal_length - length
  414. def pad1d(
  415. x: torch.Tensor,
  416. paddings: tp.Tuple[int, int],
  417. mode: str = "zeros",
  418. value: float = 0.0,
  419. ):
  420. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  421. If this is the case, we insert extra 0 padding to the right
  422. before the reflection happen.
  423. """
  424. length = x.shape[-1]
  425. padding_left, padding_right = paddings
  426. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  427. if mode == "reflect":
  428. max_pad = max(padding_left, padding_right)
  429. extra_pad = 0
  430. if length <= max_pad:
  431. extra_pad = max_pad - length + 1
  432. x = F.pad(x, (0, extra_pad))
  433. padded = F.pad(x, paddings, mode, value)
  434. end = padded.shape[-1] - extra_pad
  435. return padded[..., :end]
  436. else:
  437. return F.pad(x, paddings, mode, value)
  438. class CausalConvNet(nn.Module):
  439. def __init__(
  440. self,
  441. in_channels,
  442. out_channels,
  443. kernel_size,
  444. dilation=1,
  445. stride=1,
  446. groups=1,
  447. padding=None,
  448. ):
  449. super(CausalConvNet, self).__init__()
  450. self.conv = nn.Conv1d(
  451. in_channels,
  452. out_channels,
  453. kernel_size,
  454. stride=stride,
  455. dilation=dilation,
  456. groups=groups,
  457. )
  458. self.stride = stride
  459. self.kernel_size = (kernel_size - 1) * dilation + 1
  460. self.dilation = dilation
  461. self.padding = self.kernel_size - self.stride
  462. def forward(self, x):
  463. pad = self.padding
  464. extra_padding = get_extra_padding_for_conv1d(
  465. x, self.kernel_size, self.stride, pad
  466. )
  467. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  468. return self.conv(x).contiguous()
  469. def weight_norm(self, name="weight", dim=0):
  470. self.conv = weight_norm(self.conv, name=name, dim=dim)
  471. return self
  472. def remove_weight_norm(self):
  473. self.conv = remove_parametrizations(self.conv)
  474. return self
  475. class CausalTransConvNet(nn.Module):
  476. def __init__(
  477. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
  478. ):
  479. super(CausalTransConvNet, self).__init__()
  480. self.conv = nn.ConvTranspose1d(
  481. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  482. )
  483. self.stride = stride
  484. self.kernel_size = kernel_size
  485. def forward(self, x):
  486. x = self.conv(x)
  487. pad = self.kernel_size - self.stride
  488. padding_right = math.ceil(pad)
  489. padding_left = pad - padding_right
  490. x = unpad1d(x, (padding_left, padding_right))
  491. return x.contiguous()
  492. def weight_norm(self, name="weight", dim=0):
  493. self.conv = weight_norm(self.conv, name=name, dim=dim)
  494. return self
  495. def remove_weight_norm(self):
  496. self.conv = remove_parametrizations(self.conv)
  497. return self
  498. def CausalWNConv1d(*args, **kwargs):
  499. return CausalConvNet(*args, **kwargs).weight_norm()
  500. def CausalWNConvTranspose1d(*args, **kwargs):
  501. return CausalTransConvNet(*args, **kwargs).weight_norm()
  502. class ResidualUnit(nn.Module):
  503. def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
  504. super().__init__()
  505. conv_class = CausalWNConv1d if causal else WNConv1d
  506. pad = ((7 - 1) * dilation) // 2
  507. self.block = nn.Sequential(
  508. Snake1d(dim),
  509. conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
  510. Snake1d(dim),
  511. conv_class(dim, dim, kernel_size=1),
  512. )
  513. self.causal = causal
  514. def forward(self, x):
  515. y = self.block(x)
  516. pad = x.shape[-1] - y.shape[-1]
  517. if pad > 0:
  518. if self.causal:
  519. x = x[..., :-pad]
  520. else:
  521. x = x[..., pad // 2 : -pad // 2]
  522. return x + y
  523. class EncoderBlock(nn.Module):
  524. def __init__(
  525. self,
  526. dim: int = 16,
  527. stride: int = 1,
  528. causal: bool = False,
  529. n_t_layer: int = 0,
  530. transformer_general_config=None,
  531. ):
  532. super().__init__()
  533. conv_class = CausalWNConv1d if causal else WNConv1d
  534. transformer_module = (
  535. nn.Identity()
  536. if n_t_layer == 0
  537. else (
  538. WindowLimitedTransformer(
  539. causal=causal,
  540. input_dim=dim,
  541. window_size=getattr(transformer_general_config, "window_size", 512),
  542. config=transformer_general_config(
  543. n_layer=n_t_layer,
  544. n_head=dim // 64,
  545. dim=dim,
  546. intermediate_size=dim * 3,
  547. ),
  548. )
  549. )
  550. )
  551. self.block = nn.Sequential(
  552. ResidualUnit(dim // 2, dilation=1, causal=causal),
  553. ResidualUnit(dim // 2, dilation=3, causal=causal),
  554. ResidualUnit(dim // 2, dilation=9, causal=causal),
  555. Snake1d(dim // 2),
  556. conv_class(
  557. dim // 2,
  558. dim,
  559. kernel_size=2 * stride,
  560. stride=stride,
  561. padding=math.ceil(stride / 2),
  562. ),
  563. transformer_module,
  564. )
  565. def forward(self, x):
  566. return self.block(x)
  567. class Encoder(nn.Module):
  568. def __init__(
  569. self,
  570. d_model: int = 64,
  571. strides: list = [2, 4, 8, 8],
  572. d_latent: int = 64,
  573. n_transformer_layers: list = [0, 0, 4, 4],
  574. transformer_general_config: ModelArgs = None,
  575. causal: bool = False,
  576. ):
  577. super().__init__()
  578. conv_class = CausalWNConv1d if causal else WNConv1d
  579. # Create first convolution
  580. self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
  581. # Create EncoderBlocks that double channels as they downsample by `stride`
  582. for stride, n_t_layer in zip(strides, n_transformer_layers):
  583. d_model *= 2
  584. self.block += [
  585. EncoderBlock(
  586. d_model,
  587. stride=stride,
  588. causal=causal,
  589. n_t_layer=n_t_layer,
  590. transformer_general_config=transformer_general_config,
  591. )
  592. ]
  593. # Create last convolution
  594. self.block += [
  595. Snake1d(d_model),
  596. conv_class(d_model, d_latent, kernel_size=3, padding=1),
  597. ]
  598. # Wrap black into nn.Sequential
  599. self.block = nn.Sequential(*self.block)
  600. self.enc_dim = d_model
  601. def forward(self, x):
  602. return self.block(x)
  603. class DecoderBlock(nn.Module):
  604. def __init__(
  605. self,
  606. input_dim: int = 16,
  607. output_dim: int = 8,
  608. stride: int = 1,
  609. causal: bool = False,
  610. n_t_layer: int = 0,
  611. transformer_general_config=None,
  612. ):
  613. super().__init__()
  614. conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
  615. transformer_module = (
  616. nn.Identity()
  617. if n_t_layer == 0
  618. else (
  619. WindowLimitedTransformer(
  620. causal=causal,
  621. input_dim=input_dim,
  622. window_size=None,
  623. config=transformer_general_config(
  624. n_layer=n_t_layer,
  625. n_head=input_dim // 64,
  626. dim=input_dim,
  627. intermediate_size=input_dim * 3,
  628. ),
  629. )
  630. )
  631. )
  632. self.block = nn.Sequential(
  633. # transformer_module,
  634. Snake1d(input_dim),
  635. conv_trans_class(
  636. input_dim,
  637. output_dim,
  638. kernel_size=2 * stride,
  639. stride=stride,
  640. padding=math.ceil(stride / 2),
  641. ),
  642. ResidualUnit(output_dim, dilation=1, causal=causal),
  643. ResidualUnit(output_dim, dilation=3, causal=causal),
  644. ResidualUnit(output_dim, dilation=9, causal=causal),
  645. )
  646. def forward(self, x):
  647. return self.block(x)
  648. class Decoder(nn.Module):
  649. def __init__(
  650. self,
  651. input_channel,
  652. channels,
  653. rates,
  654. d_out: int = 1,
  655. causal: bool = False,
  656. n_transformer_layers: list = [0, 0, 0, 0],
  657. transformer_general_config=None,
  658. ):
  659. super().__init__()
  660. conv_class = CausalWNConv1d if causal else WNConv1d
  661. # Add first conv layer
  662. layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
  663. # Add upsampling + MRF blocks
  664. for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
  665. input_dim = channels // 2**i
  666. output_dim = channels // 2 ** (i + 1)
  667. layers += [
  668. DecoderBlock(
  669. input_dim,
  670. output_dim,
  671. stride,
  672. causal=causal,
  673. n_t_layer=n_t_layer,
  674. transformer_general_config=transformer_general_config,
  675. )
  676. ]
  677. # Add final conv layer
  678. layers += [
  679. Snake1d(output_dim),
  680. conv_class(output_dim, d_out, kernel_size=7, padding=3),
  681. nn.Tanh(),
  682. ]
  683. self.model = nn.Sequential(*layers)
  684. def forward(self, x):
  685. return self.model(x)
  686. class DAC(BaseModel, CodecMixin):
  687. def __init__(
  688. self,
  689. encoder_dim: int = 64,
  690. encoder_rates: List[int] = [2, 4, 8, 8],
  691. latent_dim: int = None,
  692. decoder_dim: int = 1536,
  693. decoder_rates: List[int] = [8, 8, 4, 2],
  694. quantizer: torch.nn.Module = None,
  695. sample_rate: int = 44100,
  696. causal: bool = True,
  697. encoder_transformer_layers: List[int] = [0, 0, 0, 0],
  698. decoder_transformer_layers: List[int] = [0, 0, 0, 0],
  699. overwrite_decoder: torch.nn.Module = None,
  700. transformer_general_config=None,
  701. ):
  702. super().__init__()
  703. self.encoder_dim = encoder_dim
  704. self.encoder_rates = encoder_rates
  705. self.decoder_dim = decoder_dim
  706. self.decoder_rates = decoder_rates
  707. self.sample_rate = sample_rate
  708. if latent_dim is None:
  709. latent_dim = encoder_dim * (2 ** len(encoder_rates))
  710. self.latent_dim = latent_dim
  711. self.hop_length = np.prod(encoder_rates)
  712. self.encoder = Encoder(
  713. encoder_dim,
  714. encoder_rates,
  715. latent_dim,
  716. causal=causal,
  717. n_transformer_layers=encoder_transformer_layers,
  718. transformer_general_config=transformer_general_config,
  719. )
  720. self.quantizer = quantizer
  721. if overwrite_decoder is not None:
  722. self.decoder = overwrite_decoder
  723. else:
  724. self.decoder = Decoder(
  725. latent_dim,
  726. decoder_dim,
  727. decoder_rates,
  728. causal=causal,
  729. n_transformer_layers=decoder_transformer_layers,
  730. transformer_general_config=transformer_general_config,
  731. )
  732. self.sample_rate = sample_rate
  733. self.apply(init_weights)
  734. self.delay = self.get_delay()
  735. self.frame_length = self.hop_length * 4
  736. def preprocess(self, audio_data, sample_rate):
  737. if sample_rate is None:
  738. sample_rate = self.sample_rate
  739. assert sample_rate == self.sample_rate
  740. length = audio_data.shape[-1]
  741. right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
  742. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  743. return audio_data
  744. def encode(
  745. self,
  746. audio_data: torch.Tensor,
  747. audio_lengths: torch.Tensor = None,
  748. n_quantizers: int = None,
  749. **kwargs,
  750. ):
  751. """Encode given audio data and return quantized latent codes
  752. Parameters
  753. ----------
  754. audio_data : Tensor[B x T]
  755. Audio data to encode
  756. n_quantizers : int, optional
  757. Number of quantizers to use, by default None
  758. If None, all quantizers are used.
  759. Returns
  760. -------
  761. dict
  762. A dictionary with the following keys:
  763. "z" : Tensor[B x D x T]
  764. Quantized continuous representation of input
  765. "codes" : Tensor[B x N x T]
  766. Codebook indices for each codebook
  767. (quantized discrete representation of input)
  768. "latents" : Tensor[B x N*D x T]
  769. Projected latents (continuous representation of input before quantization)
  770. "vq/commitment_loss" : Tensor[1]
  771. Commitment loss to train encoder to predict vectors closer to codebook
  772. entries
  773. "vq/codebook_loss" : Tensor[1]
  774. Codebook loss to update the codebook
  775. "length" : int
  776. Number of samples in input audio
  777. """
  778. # pad to multiple of self.frame_length
  779. if audio_data.ndim == 2:
  780. audio_data = audio_data.unsqueeze(1)
  781. length = audio_data.shape[-1]
  782. right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
  783. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  784. if audio_lengths is None:
  785. audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
  786. z = self.encoder(audio_data)
  787. vq_results = self.quantizer(z, n_quantizers, **kwargs)
  788. indices = vq_results.codes
  789. indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
  790. return indices, indices_lens
  791. def from_indices(self, indices: torch.Tensor):
  792. z = self.quantizer.decode(indices)
  793. return self.decoder(z)
  794. def decode(self, z: torch.Tensor):
  795. """Decode given latent codes and return audio data
  796. Parameters
  797. ----------
  798. z : Tensor[B x D x T]
  799. Quantized continuous representation of input
  800. length : int, optional
  801. Number of samples in output audio, by default None
  802. Returns
  803. -------
  804. dict
  805. A dictionary with the following keys:
  806. "audio" : Tensor[B x 1 x length]
  807. Decoded audio data.
  808. """
  809. return self.decoder(z)
  810. def forward(
  811. self,
  812. audio_data: torch.Tensor,
  813. template: torch.Tensor = None,
  814. mask: torch.Tensor = None,
  815. sample_rate: int = None,
  816. n_quantizers: int = None,
  817. **kwargs,
  818. ):
  819. """Model forward pass
  820. Parameters
  821. ----------
  822. audio_data : Tensor[B x 1 x T]
  823. Audio data to encode
  824. sample_rate : int, optional
  825. Sample rate of audio data in Hz, by default None
  826. If None, defaults to `self.sample_rate`
  827. n_quantizers : int, optional
  828. Number of quantizers to use, by default None.
  829. If None, all quantizers are used.
  830. Returns
  831. -------
  832. dict
  833. A dictionary with the following keys:
  834. "z" : Tensor[B x D x T]
  835. Quantized continuous representation of input
  836. "codes" : Tensor[B x N x T]
  837. Codebook indices for each codebook
  838. (quantized discrete representation of input)
  839. "latents" : Tensor[B x N*D x T]
  840. Projected latents (continuous representation of input before quantization)
  841. "vq/commitment_loss" : Tensor[1]
  842. Commitment loss to train encoder to predict vectors closer to codebook
  843. entries
  844. "vq/codebook_loss" : Tensor[1]
  845. Codebook loss to update the codebook
  846. "length" : int
  847. Number of samples in input audio
  848. "audio" : Tensor[B x 1 x length]
  849. Decoded audio data.
  850. """
  851. length = audio_data.shape[-1]
  852. audio_data = self.preprocess(audio_data, sample_rate)
  853. vq_results = self.encode(audio_data, n_quantizers, **kwargs)
  854. z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
  855. x = self.decode(z)
  856. return x[..., :length], vq_results
if __name__ == "__main__":
    # Standalone smoke-test: reconstruct audio from a pre-computed codes file.
    import hydra
    import torch
    import numpy as np
    import soundfile as sf
    from omegaconf import OmegaConf

    # Configuration paths.
    config_path = "fish_speech/configs/modded_dac_vq.yaml"
    checkpoint_path = "checkpoints/s2-pro/codec.pth"
    codes_path = "./output/codes_0.npy"  # path to your codes file
    output_path = "reconstructed_from_codes.wav"
    sample_rate = 44100  # must match the sample rate the model was trained with

    with torch.inference_mode():
        # 1. Instantiate the model from the Hydra/OmegaConf config.
        model = hydra.utils.instantiate(OmegaConf.load(config_path))
        new_sd = torch.load(checkpoint_path, map_location="cpu")
        # NOTE(review): strict=False silently ignores missing/unexpected keys —
        # confirm the checkpoint really matches this architecture.
        model.load_state_dict(new_sd, strict=False)
        model.cuda()
        model.eval()

        # 2. Load external codes (.npy).
        # Expected shape is usually [num_codebooks, seq_len]
        # or [1, num_codebooks, seq_len].
        codes_np = np.load(codes_path)
        codes_tensor = torch.from_numpy(codes_np).to(torch.long).cuda()
        # If the codes lack a batch dimension, add one: [1, num_codebooks, seq_len].
        if len(codes_tensor.shape) == 2:
            codes_tensor = codes_tensor.unsqueeze(0)
        print(f"Loaded codes shape: {codes_tensor.shape}")

        # 3. Reconstruct audio directly from the codes (decoding).
        # Note: fish_speech's model.from_indices expects a LongTensor input.
        fake_audio = model.from_indices(codes_tensor)

        # 4. Post-process and save.
        # fake_audio is usually shaped [B, C, T].
        audio_np = fake_audio.squeeze().cpu().numpy()
        # For multi-channel audio, transpose to soundfile's (samples, channels).
        if len(audio_np.shape) == 2:
            audio_np = audio_np.T
        sf.write(output_path, audio_np, sample_rate)
        print(f"重建完成。音频已保存至: {output_path}")