modules_old.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. import math
  2. from typing import Optional
  3. import torch
  4. from einops import rearrange
  5. from torch import nn
  6. from torch.nn import functional as F
  7. try:
  8. from xformers.ops import memory_efficient_attention
  9. except ImportError as e:
  10. memory_efficient_attention = None
  11. class AlibiPostionEmbedding(nn.Module):
  12. def __init__(self, nheads, maxpos):
  13. super().__init__()
  14. context_position = torch.arange(maxpos)[:, None]
  15. memory_position = torch.arange(maxpos)[None, :]
  16. relative_position = memory_position - context_position
  17. relative_position = (
  18. torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
  19. )
  20. self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
  21. alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
  22. alibi = alibi.view(nheads, maxpos, maxpos)
  23. self.register_buffer("alibi", alibi)
  24. @staticmethod
  25. def get_slopes_power_of_2(n):
  26. start = 2 ** (-(2 ** -(math.log2(n) - 3)))
  27. ratio = start
  28. return [start * ratio**i for i in range(n)]
  29. def get_slopes(self, n):
  30. if math.log2(n).is_integer():
  31. return self.get_slopes_power_of_2(n)
  32. closest_power_of_2 = 2 ** math.floor(math.log2(n))
  33. return (
  34. self.get_slopes_power_of_2(closest_power_of_2)
  35. + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
  36. )
  37. def __call__(self, x):
  38. # N, T, C
  39. return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
  40. class KVCache(nn.Module):
  41. def __init__(
  42. self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
  43. ):
  44. super().__init__()
  45. cache_shape = (max_batch_size, max_seq_length, n_heads * head_dim)
  46. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  47. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  48. def update(self, input_pos, k_val, v_val):
  49. assert input_pos is not None, "input_pos should not be None"
  50. k_out = self.k_cache
  51. v_out = self.v_cache
  52. k_out[:, input_pos] = k_val
  53. v_out[:, input_pos] = v_val
  54. return k_out, v_out
  55. class MultiheadAttention(nn.Module):
  56. def __init__(self, d_model, nhead, dropout=0.1):
  57. super().__init__()
  58. assert d_model % nhead == 0
  59. self.nhead = nhead
  60. self.d_model = d_model
  61. self.head_dim = d_model // nhead
  62. self.q_proj = nn.Linear(d_model, d_model)
  63. self.k_proj = nn.Linear(d_model, d_model)
  64. self.v_proj = nn.Linear(d_model, d_model)
  65. self.out_proj = nn.Linear(d_model, d_model)
  66. self.dropout = nn.Dropout(dropout)
  67. self.kv_cache = None
  68. def forward(
  69. self,
  70. q,
  71. k,
  72. v,
  73. attn_mask=None,
  74. key_padding_mask=None,
  75. attn_bias=None,
  76. return_weights=False,
  77. input_pos=None,
  78. ):
  79. # (B, T, C)
  80. batch_size = q.size(0)
  81. q_length = q.size(1)
  82. q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
  83. if self.kv_cache is not None:
  84. k, v = self.kv_cache.update(input_pos, k, v)
  85. k_length = k.size(1)
  86. if attn_bias is not None:
  87. assert attn_bias.size() == (
  88. self.nhead,
  89. q_length,
  90. k_length,
  91. ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
  92. attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
  93. if attn_mask is not None:
  94. assert attn_mask.size() == (
  95. q_length,
  96. k_length,
  97. ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
  98. assert attn_mask.dtype == torch.bool
  99. attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
  100. if key_padding_mask is not None:
  101. assert key_padding_mask.size() == (
  102. batch_size,
  103. k_length,
  104. ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
  105. assert key_padding_mask.dtype == torch.bool
  106. key_padding_mask = (
  107. key_padding_mask.unsqueeze(1)
  108. .unsqueeze(1)
  109. .expand(-1, self.nhead, -1, -1)
  110. )
  111. key_padding_mask = key_padding_mask.reshape(
  112. batch_size * self.nhead, 1, k_length
  113. )
  114. if attn_mask is None:
  115. attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
  116. else:
  117. attn_mask = attn_mask.logical_or(key_padding_mask)
  118. if (
  119. return_weights is False
  120. and memory_efficient_attention is not None
  121. and q.device.type == "cuda"
  122. ):
  123. # (-> b, t,. n, d)
  124. q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
  125. k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
  126. v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
  127. if attn_mask is not None:
  128. attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
  129. if attn_bias is None:
  130. attn_bias = torch.zeros_like(
  131. attn_mask, dtype=q.dtype, device=q.device
  132. )
  133. attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
  134. if attn_bias is not None:
  135. attn_bias = attn_bias.to(q.dtype)
  136. attn_output = memory_efficient_attention(
  137. q,
  138. k,
  139. v,
  140. attn_bias=attn_bias,
  141. scale=self.head_dim**-0.5,
  142. p=self.dropout.p,
  143. )
  144. attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
  145. returned_weights = None
  146. else:
  147. q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
  148. k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
  149. v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
  150. attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
  151. assert attn_weights.size() == (
  152. batch_size * self.nhead,
  153. q.size(1),
  154. k.size(1),
  155. )
  156. if attn_bias is not None:
  157. attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
  158. attn_weights = attn_weights + attn_bias
  159. if attn_mask is not None:
  160. attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
  161. attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
  162. returned_weights = attn_weights.view(
  163. batch_size, self.nhead, q.size(1), k.size(1)
  164. )
  165. attn_probs = self.dropout(attn_weights)
  166. attn_output = torch.bmm(attn_probs, v)
  167. attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
  168. attn_output = self.out_proj(attn_output)
  169. return attn_output, returned_weights
  170. class GluMLP(nn.Module):
  171. def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
  172. super().__init__()
  173. if intermediate_size is None:
  174. intermediate_size = hidden_size * (11 / 3)
  175. intermediate_size = round(intermediate_size / 8) * 8
  176. self.hidden_size = hidden_size
  177. self.intermediate_size = intermediate_size
  178. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  179. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  180. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  181. self.act_fn = activation()
  182. def forward(self, x):
  183. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  184. class RMSNorm(nn.Module):
  185. def __init__(self, hidden_size, eps=1e-6):
  186. """
  187. RMSNorm is equivalent to T5LayerNorm
  188. """
  189. super().__init__()
  190. self.weight = nn.Parameter(torch.ones(hidden_size))
  191. self.variance_epsilon = eps
  192. def forward(self, hidden_states):
  193. input_dtype = hidden_states.dtype
  194. hidden_states = hidden_states.to(torch.float32)
  195. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  196. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  197. return self.weight * hidden_states.to(input_dtype)
  198. class CrossAttentionLayer(nn.Module):
  199. def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
  200. super().__init__()
  201. self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
  202. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  203. self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
  204. self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
  205. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  206. def forward(
  207. self,
  208. tgt,
  209. memory,
  210. memory_key_padding_mask=None,
  211. input_pos=None,
  212. ):
  213. residual = tgt
  214. tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
  215. x, attn_weights = self.attn(
  216. tgt,
  217. memory,
  218. memory,
  219. key_padding_mask=memory_key_padding_mask,
  220. return_weights=True,
  221. input_pos=input_pos,
  222. )
  223. residual = x + residual
  224. x = self.post_attention_layernorm(residual)
  225. x = self.mlp(x)
  226. x = x + residual
  227. return x, attn_weights
  228. class TransformerEncoderLayer(nn.Module):
  229. def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
  230. super().__init__()
  231. self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
  232. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  233. self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
  234. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  235. def forward(
  236. self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None, input_pos=None
  237. ):
  238. residual = x
  239. x = self.input_layernorm(x)
  240. x, _ = self.attn(
  241. x,
  242. x,
  243. x,
  244. attn_bias=attn_bias,
  245. key_padding_mask=key_padding_mask,
  246. attn_mask=tgt_mask,
  247. return_weights=False,
  248. input_pos=input_pos,
  249. )
  250. residual = x + residual
  251. x = self.post_attention_layernorm(residual)
  252. x = self.mlp(x)
  253. x = x + residual
  254. return x
  255. class FishSpeechTransformer(nn.Module):
  256. def __init__(
  257. self,
  258. vocab_size,
  259. codebook_size,
  260. num_codebooks,
  261. hidden_size=1024,
  262. intermediate_size=None,
  263. nhead=16,
  264. num_encoder_layers=12,
  265. num_decoder_layers=12,
  266. dropout=0.1,
  267. alignment_position=-2,
  268. max_position=8192,
  269. ):
  270. super().__init__()
  271. self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
  272. self.decoder_embeddings = nn.ModuleList(
  273. [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
  274. )
  275. self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
  276. self.codebook_size = codebook_size
  277. self.num_codebooks = num_codebooks
  278. self.encoder = nn.ModuleList(
  279. [
  280. TransformerEncoderLayer(
  281. hidden_size=hidden_size,
  282. intermediate_size=intermediate_size,
  283. nhead=nhead,
  284. dropout=dropout,
  285. )
  286. for _ in range(num_encoder_layers)
  287. ]
  288. )
  289. self.alignment = CrossAttentionLayer(
  290. hidden_size=hidden_size,
  291. intermediate_size=intermediate_size,
  292. dropout=dropout,
  293. )
  294. if alignment_position < 0:
  295. alignment_position = num_decoder_layers + alignment_position
  296. self.alignment_position = alignment_position
  297. assert 0 <= alignment_position < num_decoder_layers
  298. self.decoder = nn.ModuleList(
  299. [
  300. TransformerEncoderLayer(
  301. hidden_size=hidden_size,
  302. intermediate_size=intermediate_size,
  303. nhead=nhead,
  304. dropout=dropout,
  305. )
  306. for _ in range(num_decoder_layers)
  307. ]
  308. )
  309. self.alibi = AlibiPostionEmbedding(nhead, max_position)
  310. self.register_buffer(
  311. "causual_mask",
  312. torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
  313. )
  314. self.max_batch_size = -1
  315. self.max_seq_length = -1
  316. def setup_kv_caches(self, max_batch_size, max_seq_length):
  317. if (
  318. self.max_seq_length >= max_seq_length
  319. and self.max_batch_size >= max_batch_size
  320. ):
  321. return
  322. if max_seq_length % 8 != 0:
  323. max_seq_length = max_seq_length + (8 - max_seq_length % 8)
  324. self.max_seq_length = max_seq_length
  325. self.max_batch_size = max_batch_size
  326. for b in self.decoder:
  327. b.attn.kv_cache = KVCache(
  328. max_batch_size, max_seq_length, b.attn.nhead, b.attn.head_dim
  329. )
  330. def forward(self, inputs, codes, input_mask=None, codes_mask=None):
  331. # x: (B, T)
  332. # y: (B, C, T)
  333. inputs = self.encoder_embedding(inputs)
  334. codes = rearrange(codes, "b c t -> c b t")
  335. codes = torch.stack(
  336. [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
  337. )
  338. codes = torch.mean(codes, dim=0) # (B, T)
  339. attn_bias = self.alibi(inputs)
  340. for layer in self.encoder:
  341. inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)
  342. attn_bias = self.alibi(codes)
  343. causual_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
  344. for idx, layer in enumerate(self.decoder):
  345. if idx == self.alignment_position:
  346. codes, _ = self.alignment(
  347. codes, inputs, memory_key_padding_mask=input_mask
  348. )
  349. codes = layer(
  350. codes,
  351. attn_bias=attn_bias,
  352. key_padding_mask=codes_mask,
  353. tgt_mask=causual_mask,
  354. )
  355. codes = self.decoder_head(codes)
  356. codes = rearrange(
  357. codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
  358. )
  359. return codes
  360. def sample_decoder(
  361. self,
  362. x: torch.Tensor,
  363. context: torch.Tensor,
  364. input_pos: torch.Tensor,
  365. **sampling_kwargs,
  366. ):
  367. attn_bias = self.alibi.alibi[:, input_pos, : self.max_seq_length]
  368. causual_mask = self.causual_mask[input_pos, : self.max_seq_length]
  369. x = rearrange(x, "b c t -> c b t")
  370. x = torch.stack(
  371. [emb(code) for emb, code in zip(self.decoder_embeddings, x)], dim=0
  372. )
  373. x = torch.mean(x, dim=0) # (B, T)
  374. for idx, layer in enumerate(self.decoder):
  375. if idx == self.alignment_position:
  376. x, _ = self.alignment(x, context)
  377. x = layer(
  378. x, attn_bias=attn_bias, input_pos=input_pos, tgt_mask=causual_mask
  379. )
  380. x = self.decoder_head(x)
  381. x = rearrange(
  382. x, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
  383. )
  384. # Never predict EOS or BOS for sub-codebooks
  385. x[:, 1:, :2] = -float("Inf")
  386. next_token, probs = [], []
  387. for i in range(self.num_codebooks):
  388. next_token_i, probs_i = self.sample(x[:, i], **sampling_kwargs)
  389. next_token.append(next_token_i)
  390. probs.append(probs_i)
  391. return torch.stack(next_token, dim=0), torch.stack(probs, dim=0)
  392. @staticmethod
  393. def multinomial_sample_one_no_sync(
  394. probs_sort,
  395. ): # Does multinomial sampling without a cuda synchronization
  396. q = torch.empty_like(probs_sort).exponential_(1)
  397. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  398. @staticmethod
  399. def logits_to_probs(
  400. logits,
  401. temperature: float = 1.0,
  402. top_p: Optional[int] = None,
  403. top_k: Optional[int] = None,
  404. ):
  405. if top_p is not None:
  406. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  407. cum_probs = torch.cumsum(
  408. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  409. )
  410. sorted_indices_to_remove = cum_probs > top_p
  411. sorted_indices_to_remove[0] = False # keep at least one option
  412. indices_to_remove = sorted_indices_to_remove.scatter(
  413. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  414. )
  415. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  416. logits = logits / max(temperature, 1e-5)
  417. if top_k is not None:
  418. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  419. pivot = v.select(-1, -1).unsqueeze(-1)
  420. logits = torch.where(logits < pivot, -float("Inf"), logits)
  421. probs = torch.nn.functional.softmax(logits, dim=-1)
  422. return probs
  423. def sample(
  424. self,
  425. logits,
  426. temperature: float = 1.0,
  427. top_p: Optional[int] = None,
  428. top_k: Optional[int] = None,
  429. ):
  430. probs = self.logits_to_probs(logits[0, -1], temperature, top_p, top_k)
  431. idx_next = self.multinomial_sample_one_no_sync(probs)
  432. return idx_next, probs
  433. def decode_n_tokens(
  434. self,
  435. cur_token: torch.Tensor,
  436. context: torch.Tensor,
  437. input_pos: torch.Tensor,
  438. num_new_tokens: int,
  439. callback=lambda _: _,
  440. **sampling_kwargs,
  441. ):
  442. new_tokens, new_probs = [], []
  443. # Sliding context window
  444. batch_size = 1
  445. back_map = torch.zeros(
  446. [batch_size, 1], device=cur_token.device, dtype=torch.long
  447. )
  448. for i in range(num_new_tokens):
  449. next_token, next_prob = self.sample_decoder(
  450. cur_token, context, input_pos, **sampling_kwargs
  451. )
  452. # index_map = torch.arange(6, device=cur_token.device)
  453. # index_map = back_map[:, -1:] + index_map.repeat(batch_size, 1)
  454. # add = torch.arange(batch_size, device=index_map.device).unsqueeze(1) #N, 1
  455. # index_map = index_map + add * t_length
  456. input_pos += 1
  457. new_tokens.append(next_token.clone())
  458. callback(new_tokens[-1])
  459. new_probs.append(next_prob.clone())
  460. if next_token[0, 0] == 1:
  461. break
  462. cur_token = next_token.view(1, self.num_codebooks, -1)
  463. return new_tokens, new_probs
  464. def compile(self):
  465. self.sampler_decoder = torch.compile(
  466. self.sample_decoder, mode="reduce-overhead", fullgraph=True
  467. )
  468. @torch.no_grad()
  469. def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
  470. # inputs: (B, T)
  471. # prompt: (B, C, T)
  472. assert inputs.size(0) == 1, "Only support batch size 1 for now"
  473. if prompt is None:
  474. prompt = torch.tensor(
  475. [[[0]] * self.num_codebooks], device=inputs.device, dtype=torch.long
  476. )
  477. T = prompt.size(2)
  478. T_new = T + max_new_tokens
  479. # Encode Features
  480. inputs = self.encoder_embedding(inputs)
  481. attn_bias = self.alibi(inputs)
  482. for layer in self.encoder:
  483. inputs = layer(inputs, attn_bias=attn_bias)
  484. device, dtype = inputs.device, inputs.dtype
  485. # Decode
  486. with torch.device(inputs.device):
  487. self.setup_kv_caches(max_batch_size=1, max_seq_length=T_new)
  488. # create an empty tensor of the expected final shape and fill in the current tokens
  489. empty = torch.empty(
  490. (1, self.num_codebooks, T_new), dtype=torch.long, device=device
  491. )
  492. empty[:, :, :T] = prompt
  493. seq = empty
  494. input_pos = torch.arange(0, T, device=device)
  495. # prefill
  496. next_token, _ = self.sample_decoder(
  497. prompt.view(1, self.num_codebooks, -1), inputs, input_pos, **sampling_kwargs
  498. )
  499. seq[:, :, T] = next_token
  500. # create an empty tensor of the expected final shape and fill in the current tokens
  501. input_pos = torch.tensor([T], device=device, dtype=torch.long)
  502. generated_tokens, _ = self.decode_n_tokens(
  503. next_token.view(1, self.num_codebooks, -1),
  504. context=inputs,
  505. input_pos=input_pos,
  506. num_new_tokens=max_new_tokens - 1,
  507. **sampling_kwargs,
  508. )
  509. generated_tokens = torch.stack(generated_tokens, dim=-1)
  510. seq = seq[:, :, : T + 1 + generated_tokens.size(-1)]
  511. seq[:, :, T + 1 :] = generated_tokens
  512. return seq
  513. if __name__ == "__main__":
  514. # mha = MultiheadAttention(512, 8, dropout=0)
  515. # mha.eval()
  516. # mha.cuda()
  517. # q, k, v = torch.randn(3, 10, 16, 512)
  518. # q, k, v = q.cuda(), k.cuda(), v.cuda()
  519. # alibi = AlibiPostionEmbedding(8, 1024)
  520. # mha.bfloat16()
  521. # q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
  522. # bias = alibi(q).bfloat16()
  523. # # Causual mask
  524. # attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
  525. # o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
  526. # print(o.size())
  527. # print(w.size())
  528. # o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
  529. # print(o1.size())
  530. # print(o[0], o1.float()[0])
  531. # assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
  532. # print("ok")
  533. # cross = CrossAttentionLayer(512, 1024, dropout=0)
  534. # cross.eval()
  535. # cross.cuda()
  536. # tgt = torch.randn(3, 10, 512).cuda()
  537. # memory = torch.randn(3, 20, 512).cuda()
  538. # o, w = cross(tgt, memory)
  539. # print(o.size())
  540. # print(w.size())
  541. # ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
  542. # ten.eval()
  543. # ten.cuda()
  544. # tgt = torch.randn(3, 10, 512).cuda()
  545. # o = ten(tgt)
  546. # print(o.size())
  547. trans = (
  548. FishSpeechTransformer(
  549. vocab_size=30000,
  550. codebook_size=120,
  551. num_codebooks=4,
  552. hidden_size=1024,
  553. intermediate_size=None,
  554. nhead=16,
  555. num_encoder_layers=12,
  556. num_decoder_layers=12,
  557. )
  558. .bfloat16()
  559. .cuda()
  560. )
  561. # Print n param
  562. print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
  563. inputs = torch.randint(0, 1000, (1, 16)).cuda()
  564. codes = torch.randint(0, 120, (1, 4, 128)).cuda()
  565. print(trans(inputs, codes).size())
  566. r = trans.inference(inputs, max_new_tokens=1024, top_k=5, temperature=0.3)
  567. print(r)