modded_dac.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080
  1. import math
  2. import typing as tp
  3. from dataclasses import dataclass
  4. from typing import List, Optional, Union
  5. import numpy as np
  6. import torch
  7. from audiotools import AudioSignal
  8. from audiotools.ml import BaseModel
  9. from dac.model.base import CodecMixin
  10. from dac.nn.layers import Snake1d, WNConv1d, WNConvTranspose1d
  11. from torch import Tensor, nn
  12. from torch.nn import functional as F
  13. from torch.nn.utils.parametrizations import weight_norm
  14. from torch.nn.utils.parametrize import remove_parametrizations
  15. @dataclass
  16. class VQResult:
  17. z: torch.Tensor
  18. codes: torch.Tensor
  19. latents: torch.Tensor
  20. codebook_loss: torch.Tensor
  21. commitment_loss: torch.Tensor
  22. semantic_distill_z: torch.Tensor | None = None
  23. def find_multiple(n: int, k: int) -> int:
  24. if n % k == 0:
  25. return n
  26. return n + k - (n % k)
  27. @dataclass
  28. class ModelArgs:
  29. block_size: int = 2048
  30. n_layer: int = 8
  31. n_head: int = 8
  32. dim: int = 512
  33. intermediate_size: int = 1536
  34. n_local_heads: int = -1
  35. head_dim: int = 64
  36. rope_base: float = 10000
  37. norm_eps: float = 1e-5
  38. dropout_rate: float = 0.1
  39. attn_dropout_rate: float = 0.1
  40. channels_first: bool = True # to be compatible with conv1d input/output
  41. pos_embed_type: str = "rope" # can be "rope" or "conformer"
  42. max_relative_position: int = 128 # for conformer-style relative position embedding
  43. window_size: int = 512 # for window limited attention
  44. def __post_init__(self):
  45. if self.n_local_heads == -1:
  46. self.n_local_heads = self.n_head
  47. if self.intermediate_size is None:
  48. hidden_dim = 4 * self.dim
  49. n_hidden = int(2 * hidden_dim / 3)
  50. self.intermediate_size = find_multiple(n_hidden, 256)
  51. assert self.pos_embed_type in [
  52. "rope",
  53. "conformer",
  54. ], "pos_embed_type must be either 'rope' or 'conformer'"
  55. class KVCache(nn.Module):
  56. def __init__(
  57. self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
  58. ):
  59. super().__init__()
  60. cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
  61. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  62. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  63. def update(self, input_pos, k_val, v_val):
  64. # input_pos: [S], k_val: [B, H, S, D]
  65. assert input_pos.shape[0] == k_val.shape[2]
  66. k_out = self.k_cache
  67. v_out = self.v_cache
  68. k_out[:, :, input_pos] = k_val
  69. v_out[:, :, input_pos] = v_val
  70. return (
  71. k_out[:, :, : input_pos.max() + 1, :],
  72. v_out[:, :, : input_pos.max() + 1, :],
  73. )
  74. def clear_cache(self, prompt_len):
  75. self.k_cache[:, :, prompt_len:, :] = torch.zeros_like(
  76. self.k_cache[:, :, prompt_len:, :]
  77. )
  78. self.v_cache[:, :, prompt_len:, :] = torch.zeros_like(
  79. self.v_cache[:, :, prompt_len:, :]
  80. )
  81. class Transformer(nn.Module):
  82. def __init__(self, config: ModelArgs) -> None:
  83. super().__init__()
  84. self.config = config
  85. self.layers = nn.ModuleList(
  86. TransformerBlock(config) for _ in range(config.n_layer)
  87. )
  88. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  89. # Only compute RoPE frequencies if using RoPE
  90. if config.pos_embed_type == "rope":
  91. freqs_cis = precompute_freqs_cis(
  92. 327680, self.config.head_dim, self.config.rope_base
  93. )
  94. self.register_buffer("freqs_cis", freqs_cis, persistent=False)
  95. else:
  96. self.register_buffer("freqs_cis", None)
  97. causal_mask = torch.tril(torch.ones(32768, 32768, dtype=torch.bool))
  98. self.register_buffer("causal_mask", causal_mask, persistent=False)
  99. self.max_batch_size = -1
  100. self.max_seq_length = -1
  101. self.use_kv_cache = False
  102. def setup_caches(self, max_batch_size, max_seq_length):
  103. """
  104. This method will only be called during inference when using KV cache.
  105. """
  106. head_dim = self.config.dim // self.config.n_head
  107. max_seq_length = find_multiple(max_seq_length, 8)
  108. self.max_seq_length = max_seq_length
  109. self.max_batch_size = max_batch_size
  110. dtype = self.norm.weight.dtype
  111. device = self.norm.weight.device
  112. for b in self.layers:
  113. b.attention.kv_cache = KVCache(
  114. max_batch_size,
  115. max_seq_length,
  116. self.config.n_local_heads,
  117. head_dim,
  118. dtype,
  119. ).to(device)
  120. self.use_kv_cache = True
  121. def forward(
  122. self,
  123. x: Tensor,
  124. input_pos: Optional[Tensor] = None,
  125. mask: Optional[Tensor] = None,
  126. ) -> Tensor:
  127. if self.config.pos_embed_type == "rope":
  128. assert (
  129. self.freqs_cis is not None
  130. ), "RoPE frequencies must be initialized for RoPE positional embedding"
  131. # print("MAX", input_pos.max())
  132. freqs_cis = self.freqs_cis[input_pos]
  133. else:
  134. freqs_cis = None
  135. if mask is None: # in case of non-causal model
  136. if not self.training and self.use_kv_cache:
  137. mask = self.causal_mask[None, None, input_pos]
  138. mask = mask[..., : input_pos.max() + 1]
  139. else:
  140. mask = self.causal_mask[None, None, input_pos]
  141. mask = mask[..., input_pos]
  142. for i, layer in enumerate(self.layers):
  143. x = layer(x, input_pos, freqs_cis, mask)
  144. x = self.norm(x)
  145. return x
  146. class TransformerBlock(nn.Module):
  147. def __init__(self, config: ModelArgs) -> None:
  148. super().__init__()
  149. self.attention = Attention(config)
  150. self.feed_forward = FeedForward(config)
  151. self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
  152. self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
  153. self.attention_layer_scale = LayerScale(config.dim, inplace=True)
  154. self.ffn_layer_scale = LayerScale(config.dim, inplace=True)
  155. def forward(
  156. self,
  157. x: Tensor,
  158. input_pos: Tensor,
  159. freqs_cis: Tensor,
  160. mask: Tensor,
  161. ) -> Tensor:
  162. h = x + self.attention_layer_scale(
  163. self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  164. )
  165. out = h + self.ffn_layer_scale(self.feed_forward(self.ffn_norm(h)))
  166. return out
  167. class Attention(nn.Module):
  168. def __init__(self, config: ModelArgs):
  169. super().__init__()
  170. assert config.dim % config.n_head == 0
  171. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  172. # key, query, value projections for all heads, but in a batch
  173. self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
  174. self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
  175. self.kv_cache = None
  176. self.n_head = config.n_head
  177. self.head_dim = config.head_dim
  178. self.n_local_heads = config.n_local_heads
  179. self.dim = config.dim
  180. self.attn_dropout_rate = config.attn_dropout_rate
  181. self.pos_embed_type = config.pos_embed_type
  182. # Add relative position embedding for conformer-style
  183. if self.pos_embed_type == "conformer":
  184. self.max_relative_position = config.max_relative_position
  185. num_pos_embeddings = 2 * config.max_relative_position + 1
  186. self.rel_pos_embeddings = nn.Parameter(
  187. torch.zeros(num_pos_embeddings, self.head_dim)
  188. )
  189. nn.init.normal_(self.rel_pos_embeddings, mean=0.0, std=0.02)
  190. def _compute_conformer_pos_scores(self, q: Tensor, seqlen: int) -> Tensor:
  191. # q: [B, H, S, D]
  192. # Returns: [B, H, S, S]
  193. positions = torch.arange(seqlen, device=q.device)
  194. relative_positions = positions.unsqueeze(1) - positions.unsqueeze(0) # [S, S]
  195. relative_positions = torch.clamp(
  196. relative_positions + self.max_relative_position,
  197. 0,
  198. 2 * self.max_relative_position,
  199. )
  200. rel_embeddings = self.rel_pos_embeddings[relative_positions] # [S, S, D]
  201. # Compute attention scores with relative position embeddings
  202. q = q.transpose(1, 2) # [B, S, H, D]
  203. rel_logits = torch.matmul(q, rel_embeddings.transpose(-2, -1)) # [B, S, H, S]
  204. rel_logits = rel_logits.transpose(1, 2) # [B, H, S, S]
  205. return rel_logits
  206. def forward(
  207. self,
  208. x: Tensor,
  209. freqs_cis: Tensor,
  210. mask: Tensor,
  211. input_pos: Optional[Tensor] = None,
  212. ) -> Tensor:
  213. bsz, seqlen, _ = x.shape
  214. print(f"Attention forward self.n_local_heads {self.n_local_heads}, self.head_dim {self.head_dim}")
  215. kv_size = self.n_local_heads * self.head_dim
  216. q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
  217. context_seqlen = seqlen
  218. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  219. k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  220. v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
  221. # if self.pos_embed_type == "rope":
  222. # q = apply_rotary_emb(q, freqs_cis)
  223. # k = apply_rotary_emb(k, freqs_cis)
  224. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  225. if self.kv_cache is not None:
  226. k, v = self.kv_cache.update(input_pos, k, v)
  227. if input_pos is not None:
  228. # =========================
  229. # 🔥 KV cache window 裁剪(核心优化)
  230. # =========================
  231. max_context = 4096 # ⭐ 推荐 4K 或 8K
  232. # 当前有效长度
  233. seq_len = int(input_pos.max().item()) + 1
  234. # window 起点
  235. start = max(0, seq_len - max_context)
  236. # 裁剪 KV
  237. k = k[:, :, start:seq_len, :]
  238. v = v[:, :, start:seq_len, :]
  239. # =========================
  240. # 🔥 同步裁剪 mask(如果有)
  241. # =========================
  242. if mask is not None:
  243. mask = mask[:, :, :, start:seq_len]
  244. # =========================
  245. # 🔥 同步裁剪 RoPE(关键,不然会炸)
  246. # =========================
  247. print(f"input_pos.dtype {input_pos.dtype}")
  248. assert input_pos.dtype == torch.long
  249. freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
  250. if self.pos_embed_type == "rope":
  251. q = apply_rotary_emb(q, freqs_cis)
  252. k = apply_rotary_emb(k, freqs_cis)
  253. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  254. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  255. if self.pos_embed_type == "conformer":
  256. # Compute attention scores
  257. scale = 1.0 / math.sqrt(self.head_dim)
  258. scores = torch.matmul(q, k.transpose(-2, -1)) * scale
  259. # Add relative position embeddings for conformer-style
  260. rel_scores = self._compute_conformer_pos_scores(q, seqlen)
  261. scores = scores + rel_scores
  262. # Apply attention
  263. if mask is not None:
  264. scores = scores.masked_fill(~mask, float("-inf"))
  265. attn = F.softmax(scores, dim=-1)
  266. if self.attn_dropout_rate > 0 and self.training:
  267. attn = F.dropout(attn, p=self.attn_dropout_rate)
  268. y = torch.matmul(attn, v)
  269. else:
  270. y = F.scaled_dot_product_attention(
  271. q,
  272. k,
  273. v,
  274. dropout_p=self.attn_dropout_rate if self.training else 0.0,
  275. attn_mask=mask,
  276. )
  277. # is_causal=True)
  278. y = (
  279. y.transpose(1, 2)
  280. .contiguous()
  281. .view(bsz, seqlen, self.head_dim * self.n_head)
  282. )
  283. y = self.wo(y)
  284. return y
  285. class FeedForward(nn.Module):
  286. def __init__(self, config: ModelArgs) -> None:
  287. super().__init__()
  288. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  289. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  290. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  291. self.dropout = nn.Dropout(config.dropout_rate)
  292. def forward(self, x: Tensor) -> Tensor:
  293. return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
  294. class RMSNorm(nn.Module):
  295. def __init__(self, dim: int, eps: float = 1e-5):
  296. super().__init__()
  297. self.eps = eps
  298. self.weight = nn.Parameter(torch.ones(dim))
  299. def _norm(self, x):
  300. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  301. def forward(self, x: Tensor) -> Tensor:
  302. output = self._norm(x.float()).type_as(x)
  303. return output * self.weight
  304. class LayerScale(nn.Module):
  305. def __init__(
  306. self,
  307. dim: int,
  308. init_values: Union[float, Tensor] = 1e-2,
  309. inplace: bool = False,
  310. ) -> None:
  311. super().__init__()
  312. self.inplace = inplace
  313. self.gamma = nn.Parameter(init_values * torch.ones(dim))
  314. def forward(self, x: Tensor) -> Tensor:
  315. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  316. class WindowLimitedTransformer(Transformer):
  317. """
  318. Transformer with window limited attention, causal.
  319. """
  320. def __init__(
  321. self,
  322. config: ModelArgs,
  323. input_dim: int = 512,
  324. window_size: Optional[int] = None,
  325. causal: bool = True,
  326. look_ahead_conv: nn.Module = None,
  327. ):
  328. super().__init__(config)
  329. self.window_size = window_size
  330. self.causal = causal
  331. self.channels_first = config.channels_first
  332. self.look_ahead_conv = (
  333. look_ahead_conv if look_ahead_conv is not None else nn.Identity()
  334. )
  335. self.input_proj = (
  336. nn.Linear(input_dim, config.dim)
  337. if input_dim != config.dim
  338. else nn.Identity()
  339. )
  340. self.output_proj = (
  341. nn.Linear(config.dim, input_dim)
  342. if input_dim != config.dim
  343. else nn.Identity()
  344. )
  345. def make_window_limited_mask(
  346. self,
  347. max_length: int,
  348. x_lens: Optional[Tensor] = None,
  349. ) -> Tensor:
  350. """
  351. Make mask to form window limited attention.
  352. """
  353. if self.causal:
  354. mask = torch.tril(torch.ones(max_length, max_length))
  355. row_indices = torch.arange(max_length).view(-1, 1)
  356. window_size = self.window_size or max_length
  357. valid_range = (row_indices - window_size + 1).clamp(min=0)
  358. column_indices = torch.arange(max_length)
  359. mask = (column_indices >= valid_range) & mask.bool()
  360. else:
  361. raise NotImplementedError
  362. mask = mask.bool()[None, None]
  363. return mask
  364. def make_mask(
  365. self,
  366. max_length: int,
  367. x_lens: Optional[Tensor] = None,
  368. ) -> Tensor:
  369. """
  370. Make ordinary mask if window size is not specified.
  371. """
  372. if self.causal:
  373. mask = torch.tril(torch.ones(max_length, max_length))
  374. else:
  375. mask = torch.ones(max_length, max_length)
  376. mask = mask.bool()[None, None]
  377. for i, x_len in enumerate(x_lens):
  378. mask[:x_len, i] = 0
  379. mask = mask.bool()[None, None]
  380. return mask
  381. def forward(
  382. self,
  383. x: Tensor,
  384. x_lens: Optional[Tensor] = None,
  385. ) -> Tensor:
  386. if self.channels_first:
  387. x = x.transpose(1, 2)
  388. x = self.input_proj(x) # (B, T, D)
  389. x = self.look_ahead_conv(x)
  390. input_pos = torch.arange(x.shape[1], device=x.device)
  391. # construct mask to form window limited attention
  392. max_length = x.shape[1]
  393. if self.window_size is not None:
  394. mask = self.make_window_limited_mask(max_length, x_lens)
  395. else:
  396. mask = self.make_mask(max_length, x_lens)
  397. mask = mask.to(x.device)
  398. x = super().forward(x, input_pos, mask)
  399. x = self.output_proj(x) # (B, T, D)
  400. if self.channels_first:
  401. x = x.transpose(1, 2)
  402. return x
  403. def precompute_freqs_cis(
  404. seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
  405. ) -> Tensor:
  406. freqs = 1.0 / (
  407. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  408. )
  409. t = torch.arange(seq_len, device=freqs.device)
  410. freqs = torch.outer(t, freqs)
  411. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  412. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  413. return cache.to(dtype=dtype)
  414. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  415. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  416. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  417. x_out2 = torch.stack(
  418. [
  419. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  420. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  421. ],
  422. -1,
  423. )
  424. x_out2 = x_out2.flatten(3)
  425. return x_out2.type_as(x)
  426. def init_weights(m):
  427. if isinstance(m, nn.Conv1d):
  428. nn.init.trunc_normal_(m.weight, std=0.02)
  429. nn.init.constant_(m.bias, 0)
  430. def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  431. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  432. padding_left, padding_right = paddings
  433. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  434. assert (padding_left + padding_right) <= x.shape[-1]
  435. end = x.shape[-1] - padding_right
  436. return x[..., padding_left:end]
  437. def get_extra_padding_for_conv1d(
  438. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  439. ) -> int:
  440. """See `pad_for_conv1d`."""
  441. length = x.shape[-1]
  442. n_frames = (length - kernel_size + padding_total) / stride + 1
  443. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  444. return ideal_length - length
  445. def pad1d(
  446. x: torch.Tensor,
  447. paddings: tp.Tuple[int, int],
  448. mode: str = "zeros",
  449. value: float = 0.0,
  450. ):
  451. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  452. If this is the case, we insert extra 0 padding to the right
  453. before the reflection happen.
  454. """
  455. length = x.shape[-1]
  456. padding_left, padding_right = paddings
  457. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  458. if mode == "reflect":
  459. max_pad = max(padding_left, padding_right)
  460. extra_pad = 0
  461. if length <= max_pad:
  462. extra_pad = max_pad - length + 1
  463. x = F.pad(x, (0, extra_pad))
  464. padded = F.pad(x, paddings, mode, value)
  465. end = padded.shape[-1] - extra_pad
  466. return padded[..., :end]
  467. else:
  468. return F.pad(x, paddings, mode, value)
  469. class CausalConvNet(nn.Module):
  470. def __init__(
  471. self,
  472. in_channels,
  473. out_channels,
  474. kernel_size,
  475. dilation=1,
  476. stride=1,
  477. groups=1,
  478. padding=None,
  479. ):
  480. super(CausalConvNet, self).__init__()
  481. self.conv = nn.Conv1d(
  482. in_channels,
  483. out_channels,
  484. kernel_size,
  485. stride=stride,
  486. dilation=dilation,
  487. groups=groups,
  488. )
  489. self.stride = stride
  490. self.kernel_size = (kernel_size - 1) * dilation + 1
  491. self.dilation = dilation
  492. self.padding = self.kernel_size - self.stride
  493. def forward(self, x):
  494. pad = self.padding
  495. extra_padding = get_extra_padding_for_conv1d(
  496. x, self.kernel_size, self.stride, pad
  497. )
  498. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  499. return self.conv(x).contiguous()
  500. def weight_norm(self, name="weight", dim=0):
  501. self.conv = weight_norm(self.conv, name=name, dim=dim)
  502. return self
  503. def remove_weight_norm(self):
  504. self.conv = remove_parametrizations(self.conv)
  505. return self
  506. class CausalTransConvNet(nn.Module):
  507. def __init__(
  508. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
  509. ):
  510. super(CausalTransConvNet, self).__init__()
  511. self.conv = nn.ConvTranspose1d(
  512. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  513. )
  514. self.stride = stride
  515. self.kernel_size = kernel_size
  516. def forward(self, x):
  517. x = self.conv(x)
  518. pad = self.kernel_size - self.stride
  519. padding_right = math.ceil(pad)
  520. padding_left = pad - padding_right
  521. x = unpad1d(x, (padding_left, padding_right))
  522. return x.contiguous()
  523. def weight_norm(self, name="weight", dim=0):
  524. self.conv = weight_norm(self.conv, name=name, dim=dim)
  525. return self
  526. def remove_weight_norm(self):
  527. self.conv = remove_parametrizations(self.conv)
  528. return self
  529. def CausalWNConv1d(*args, **kwargs):
  530. return CausalConvNet(*args, **kwargs).weight_norm()
  531. def CausalWNConvTranspose1d(*args, **kwargs):
  532. return CausalTransConvNet(*args, **kwargs).weight_norm()
  533. class ResidualUnit(nn.Module):
  534. def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
  535. super().__init__()
  536. conv_class = CausalWNConv1d if causal else WNConv1d
  537. pad = ((7 - 1) * dilation) // 2
  538. self.block = nn.Sequential(
  539. Snake1d(dim),
  540. conv_class(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
  541. Snake1d(dim),
  542. conv_class(dim, dim, kernel_size=1),
  543. )
  544. self.causal = causal
  545. def forward(self, x):
  546. y = self.block(x)
  547. pad = x.shape[-1] - y.shape[-1]
  548. if pad > 0:
  549. if self.causal:
  550. x = x[..., :-pad]
  551. else:
  552. x = x[..., pad // 2: -pad // 2]
  553. return x + y
  554. class EncoderBlock(nn.Module):
  555. def __init__(
  556. self,
  557. dim: int = 16,
  558. stride: int = 1,
  559. causal: bool = False,
  560. n_t_layer: int = 0,
  561. transformer_general_config=None,
  562. ):
  563. super().__init__()
  564. conv_class = CausalWNConv1d if causal else WNConv1d
  565. transformer_module = (
  566. nn.Identity()
  567. if n_t_layer == 0
  568. else (
  569. WindowLimitedTransformer(
  570. causal=causal,
  571. input_dim=dim,
  572. window_size=getattr(transformer_general_config, "window_size", 512),
  573. config=transformer_general_config(
  574. n_layer=n_t_layer,
  575. n_head=dim // 64,
  576. dim=dim,
  577. intermediate_size=dim * 3,
  578. ),
  579. )
  580. )
  581. )
  582. self.block = nn.Sequential(
  583. ResidualUnit(dim // 2, dilation=1, causal=causal),
  584. ResidualUnit(dim // 2, dilation=3, causal=causal),
  585. ResidualUnit(dim // 2, dilation=9, causal=causal),
  586. Snake1d(dim // 2),
  587. conv_class(
  588. dim // 2,
  589. dim,
  590. kernel_size=2 * stride,
  591. stride=stride,
  592. padding=math.ceil(stride / 2),
  593. ),
  594. transformer_module,
  595. )
  596. def forward(self, x):
  597. return self.block(x)
  598. class Encoder(nn.Module):
  599. def __init__(
  600. self,
  601. d_model: int = 64,
  602. strides: list = [2, 4, 8, 8],
  603. d_latent: int = 64,
  604. n_transformer_layers: list = [0, 0, 4, 4],
  605. transformer_general_config: ModelArgs = None,
  606. causal: bool = False,
  607. ):
  608. super().__init__()
  609. conv_class = CausalWNConv1d if causal else WNConv1d
  610. # Create first convolution
  611. self.block = [conv_class(1, d_model, kernel_size=7, padding=3)]
  612. # Create EncoderBlocks that double channels as they downsample by `stride`
  613. for stride, n_t_layer in zip(strides, n_transformer_layers):
  614. d_model *= 2
  615. self.block += [
  616. EncoderBlock(
  617. d_model,
  618. stride=stride,
  619. causal=causal,
  620. n_t_layer=n_t_layer,
  621. transformer_general_config=transformer_general_config,
  622. )
  623. ]
  624. # Create last convolution
  625. self.block += [
  626. Snake1d(d_model),
  627. conv_class(d_model, d_latent, kernel_size=3, padding=1),
  628. ]
  629. # Wrap black into nn.Sequential
  630. self.block = nn.Sequential(*self.block)
  631. self.enc_dim = d_model
  632. def forward(self, x):
  633. return self.block(x)
  634. class DecoderBlock(nn.Module):
  635. def __init__(
  636. self,
  637. input_dim: int = 16,
  638. output_dim: int = 8,
  639. stride: int = 1,
  640. causal: bool = False,
  641. n_t_layer: int = 0,
  642. transformer_general_config=None,
  643. ):
  644. super().__init__()
  645. conv_trans_class = CausalWNConvTranspose1d if causal else WNConvTranspose1d
  646. transformer_module = (
  647. nn.Identity()
  648. if n_t_layer == 0
  649. else (
  650. WindowLimitedTransformer(
  651. causal=causal,
  652. input_dim=input_dim,
  653. window_size=None,
  654. config=transformer_general_config(
  655. n_layer=n_t_layer,
  656. n_head=input_dim // 64,
  657. dim=input_dim,
  658. intermediate_size=input_dim * 3,
  659. ),
  660. )
  661. )
  662. )
  663. self.block = nn.Sequential(
  664. # transformer_module,
  665. Snake1d(input_dim),
  666. conv_trans_class(
  667. input_dim,
  668. output_dim,
  669. kernel_size=2 * stride,
  670. stride=stride,
  671. padding=math.ceil(stride / 2),
  672. ),
  673. ResidualUnit(output_dim, dilation=1, causal=causal),
  674. ResidualUnit(output_dim, dilation=3, causal=causal),
  675. ResidualUnit(output_dim, dilation=9, causal=causal),
  676. )
  677. def forward(self, x):
  678. return self.block(x)
  679. class Decoder(nn.Module):
  680. def __init__(
  681. self,
  682. input_channel,
  683. channels,
  684. rates,
  685. d_out: int = 1,
  686. causal: bool = False,
  687. n_transformer_layers: list = [0, 0, 0, 0],
  688. transformer_general_config=None,
  689. ):
  690. super().__init__()
  691. conv_class = CausalWNConv1d if causal else WNConv1d
  692. # Add first conv layer
  693. layers = [conv_class(input_channel, channels, kernel_size=7, padding=3)]
  694. # Add upsampling + MRF blocks
  695. for i, (stride, n_t_layer) in enumerate(zip(rates, n_transformer_layers)):
  696. input_dim = channels // 2 ** i
  697. output_dim = channels // 2 ** (i + 1)
  698. layers += [
  699. DecoderBlock(
  700. input_dim,
  701. output_dim,
  702. stride,
  703. causal=causal,
  704. n_t_layer=n_t_layer,
  705. transformer_general_config=transformer_general_config,
  706. )
  707. ]
  708. # Add final conv layer
  709. layers += [
  710. Snake1d(output_dim),
  711. conv_class(output_dim, d_out, kernel_size=7, padding=3),
  712. nn.Tanh(),
  713. ]
  714. self.model = nn.Sequential(*layers)
  715. def forward(self, x):
  716. return self.model(x)
  717. class DAC(BaseModel, CodecMixin):
  718. def __init__(
  719. self,
  720. encoder_dim: int = 64,
  721. encoder_rates: List[int] = [2, 4, 8, 8],
  722. latent_dim: int = None,
  723. decoder_dim: int = 1536,
  724. decoder_rates: List[int] = [8, 8, 4, 2],
  725. quantizer: torch.nn.Module = None,
  726. sample_rate: int = 44100,
  727. causal: bool = True,
  728. encoder_transformer_layers: List[int] = [0, 0, 0, 0],
  729. decoder_transformer_layers: List[int] = [0, 0, 0, 0],
  730. overwrite_decoder: torch.nn.Module = None,
  731. transformer_general_config=None,
  732. ):
  733. super().__init__()
  734. self.encoder_dim = encoder_dim
  735. self.encoder_rates = encoder_rates
  736. self.decoder_dim = decoder_dim
  737. self.decoder_rates = decoder_rates
  738. self.sample_rate = sample_rate
  739. if latent_dim is None:
  740. latent_dim = encoder_dim * (2 ** len(encoder_rates))
  741. self.latent_dim = latent_dim
  742. self.hop_length = np.prod(encoder_rates)
  743. self.encoder = Encoder(
  744. encoder_dim,
  745. encoder_rates,
  746. latent_dim,
  747. causal=causal,
  748. n_transformer_layers=encoder_transformer_layers,
  749. transformer_general_config=transformer_general_config,
  750. )
  751. self.quantizer = quantizer
  752. if overwrite_decoder is not None:
  753. self.decoder = overwrite_decoder
  754. else:
  755. self.decoder = Decoder(
  756. latent_dim,
  757. decoder_dim,
  758. decoder_rates,
  759. causal=causal,
  760. n_transformer_layers=decoder_transformer_layers,
  761. transformer_general_config=transformer_general_config,
  762. )
  763. self.sample_rate = sample_rate
  764. self.apply(init_weights)
  765. self.delay = self.get_delay()
  766. self.frame_length = self.hop_length * 4
  767. def preprocess(self, audio_data, sample_rate):
  768. if sample_rate is None:
  769. sample_rate = self.sample_rate
  770. assert sample_rate == self.sample_rate
  771. length = audio_data.shape[-1]
  772. right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
  773. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  774. return audio_data
  775. def encode(
  776. self,
  777. audio_data: torch.Tensor,
  778. audio_lengths: torch.Tensor = None,
  779. n_quantizers: int = None,
  780. **kwargs,
  781. ):
  782. """Encode given audio data and return quantized latent codes
  783. Parameters
  784. ----------
  785. audio_data : Tensor[B x T]
  786. Audio data to encode
  787. n_quantizers : int, optional
  788. Number of quantizers to use, by default None
  789. If None, all quantizers are used.
  790. Returns
  791. -------
  792. dict
  793. A dictionary with the following keys:
  794. "z" : Tensor[B x D x T]
  795. Quantized continuous representation of input
  796. "codes" : Tensor[B x N x T]
  797. Codebook indices for each codebook
  798. (quantized discrete representation of input)
  799. "latents" : Tensor[B x N*D x T]
  800. Projected latents (continuous representation of input before quantization)
  801. "vq/commitment_loss" : Tensor[1]
  802. Commitment loss to train encoder to predict vectors closer to codebook
  803. entries
  804. "vq/codebook_loss" : Tensor[1]
  805. Codebook loss to update the codebook
  806. "length" : int
  807. Number of samples in input audio
  808. """
  809. # pad to multiple of self.frame_length
  810. if audio_data.ndim == 2:
  811. audio_data = audio_data.unsqueeze(1)
  812. length = audio_data.shape[-1]
  813. right_pad = math.ceil(length / self.frame_length) * self.frame_length - length
  814. audio_data = nn.functional.pad(audio_data, (0, right_pad))
  815. if audio_lengths is None:
  816. audio_lengths = torch.LongTensor([length + right_pad]).to(audio_data.device)
  817. z = self.encoder(audio_data)
  818. vq_results = self.quantizer(z, n_quantizers, **kwargs)
  819. indices = vq_results.codes
  820. indices_lens = torch.ceil(audio_lengths / self.frame_length).long()
  821. return indices, indices_lens
  822. def from_indices(self, indices: torch.Tensor):
  823. z = self.quantizer.decode(indices)
  824. return self.decoder(z)
  825. def decode(self, z: torch.Tensor):
  826. """Decode given latent codes and return audio data
  827. Parameters
  828. ----------
  829. z : Tensor[B x D x T]
  830. Quantized continuous representation of input
  831. length : int, optional
  832. Number of samples in output audio, by default None
  833. Returns
  834. -------
  835. dict
  836. A dictionary with the following keys:
  837. "audio" : Tensor[B x 1 x length]
  838. Decoded audio data.
  839. """
  840. return self.decoder(z)
  841. def forward(
  842. self,
  843. audio_data: torch.Tensor,
  844. template: torch.Tensor = None,
  845. mask: torch.Tensor = None,
  846. sample_rate: int = None,
  847. n_quantizers: int = None,
  848. **kwargs,
  849. ):
  850. """Model forward pass
  851. Parameters
  852. ----------
  853. audio_data : Tensor[B x 1 x T]
  854. Audio data to encode
  855. sample_rate : int, optional
  856. Sample rate of audio data in Hz, by default None
  857. If None, defaults to `self.sample_rate`
  858. n_quantizers : int, optional
  859. Number of quantizers to use, by default None.
  860. If None, all quantizers are used.
  861. Returns
  862. -------
  863. dict
  864. A dictionary with the following keys:
  865. "z" : Tensor[B x D x T]
  866. Quantized continuous representation of input
  867. "codes" : Tensor[B x N x T]
  868. Codebook indices for each codebook
  869. (quantized discrete representation of input)
  870. "latents" : Tensor[B x N*D x T]
  871. Projected latents (continuous representation of input before quantization)
  872. "vq/commitment_loss" : Tensor[1]
  873. Commitment loss to train encoder to predict vectors closer to codebook
  874. entries
  875. "vq/codebook_loss" : Tensor[1]
  876. Codebook loss to update the codebook
  877. "length" : int
  878. Number of samples in input audio
  879. "audio" : Tensor[B x 1 x length]
  880. Decoded audio data.
  881. """
  882. length = audio_data.shape[-1]
  883. audio_data = self.preprocess(audio_data, sample_rate)
  884. vq_results = self.encode(audio_data, n_quantizers, **kwargs)
  885. z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
  886. x = self.decode(z)
  887. return x[..., :length], vq_results
  888. if __name__ == "__main__":
  889. import hydra
  890. import numpy as np
  891. import soundfile as sf
  892. import torch
  893. from omegaconf import OmegaConf
  894. # 配置路径
  895. config_path = "fish_speech/configs/modded_dac_vq.yaml"
  896. checkpoint_path = "checkpoints/s2-pro/codec.pth"
  897. codes_path = "./output/codes_0.npy" # 你的 codes 文件路径
  898. output_path = "reconstructed_from_codes.wav"
  899. sample_rate = 44100 # 请确保采样率与模型训练时一致
  900. with torch.inference_mode():
  901. # 1. 初始化模型
  902. model = hydra.utils.instantiate(OmegaConf.load(config_path))
  903. new_sd = torch.load(checkpoint_path, map_location="cpu")
  904. model.load_state_dict(new_sd, strict=False)
  905. model.cuda()
  906. model.eval()
  907. # 2. 加载外部 codes (.npy)
  908. # 预期 shape 通常为 [num_codebooks, seq_len] 或 [1, num_codebooks, seq_len]
  909. codes_np = np.load(codes_path)
  910. codes_tensor = torch.from_numpy(codes_np).to(torch.long).cuda()
  911. # 如果 codes 没有 batch 维度,增加一个维度 [1, num_codebooks, seq_len]
  912. if len(codes_tensor.shape) == 2:
  913. codes_tensor = codes_tensor.unsqueeze(0)
  914. print(f"Loaded codes shape: {codes_tensor.shape}")
  915. # 3. 直接从 codes 重建音频 (Decoding)
  916. # 注意:fish_speech 的 model.from_indices 通常接受的输入是 LongTensor
  917. fake_audio = model.from_indices(codes_tensor)
  918. # 4. 后处理与保存
  919. # fake_audio 形状通常为 [B, C, T]
  920. audio_np = fake_audio.squeeze().cpu().numpy()
  921. # 如果是多声道,转置为 soundfile 要求的 (samples, channels)
  922. if len(audio_np.shape) == 2:
  923. audio_np = audio_np.T
  924. sf.write(output_path, audio_np, sample_rate)
  925. print(f"重建完成。音频已保存至: {output_path}")