modules.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. import math
  2. import torch
  3. from einops import rearrange
  4. from torch import nn
  5. from torch.nn import functional as F
  6. try:
  7. from xformers.ops import memory_efficient_attention
  8. except ImportError as e:
  9. memory_efficient_attention = None
  10. class AlibiPostionEmbedding(nn.Module):
  11. def __init__(self, nheads, maxpos):
  12. super().__init__()
  13. context_position = torch.arange(maxpos)[:, None]
  14. memory_position = torch.arange(maxpos)[None, :]
  15. relative_position = memory_position - context_position
  16. relative_position = (
  17. torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
  18. )
  19. self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
  20. alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
  21. alibi = alibi.view(nheads, maxpos, maxpos)
  22. self.register_buffer("alibi", alibi)
  23. @staticmethod
  24. def get_slopes_power_of_2(n):
  25. start = 2 ** (-(2 ** -(math.log2(n) - 3)))
  26. ratio = start
  27. return [start * ratio**i for i in range(n)]
  28. def get_slopes(self, n):
  29. if math.log2(n).is_integer():
  30. return self.get_slopes_power_of_2(n)
  31. closest_power_of_2 = 2 ** math.floor(math.log2(n))
  32. return (
  33. self.get_slopes_power_of_2(closest_power_of_2)
  34. + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
  35. )
  36. def __call__(self, x):
  37. # N, T, C
  38. return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
  39. class MultiheadAttention(nn.Module):
  40. def __init__(self, d_model, nhead, dropout=0.1):
  41. super().__init__()
  42. assert d_model % nhead == 0
  43. self.nhead = nhead
  44. self.d_model = d_model
  45. self.head_dim = d_model // nhead
  46. self.q_proj = nn.Linear(d_model, d_model)
  47. self.k_proj = nn.Linear(d_model, d_model)
  48. self.v_proj = nn.Linear(d_model, d_model)
  49. self.out_proj = nn.Linear(d_model, d_model)
  50. self.dropout = nn.Dropout(dropout)
  51. def forward(
  52. self,
  53. q,
  54. k,
  55. v,
  56. attn_mask=None,
  57. key_padding_mask=None,
  58. attn_bias=None,
  59. past_kv=None,
  60. return_weights=False,
  61. ):
  62. # (B, T, C)
  63. batch_size = q.size(0)
  64. q_length = q.size(1)
  65. k_length = k.size(1)
  66. if past_kv is not None:
  67. k, v = torch.cat([past_kv, k], 1), torch.cat([past_kv, v], 1)
  68. if attn_bias is not None:
  69. assert attn_bias.size() == (
  70. self.nhead,
  71. q_length,
  72. k_length,
  73. ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
  74. attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
  75. if attn_mask is not None:
  76. assert attn_mask.size() == (
  77. q_length,
  78. k_length,
  79. ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
  80. assert attn_mask.dtype == torch.bool
  81. attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
  82. if key_padding_mask is not None:
  83. assert key_padding_mask.size() == (
  84. batch_size,
  85. k_length,
  86. ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
  87. assert key_padding_mask.dtype == torch.bool
  88. key_padding_mask = (
  89. key_padding_mask.unsqueeze(1)
  90. .unsqueeze(1)
  91. .expand(-1, self.nhead, -1, -1)
  92. )
  93. key_padding_mask = key_padding_mask.reshape(
  94. batch_size * self.nhead, 1, k_length
  95. )
  96. if attn_mask is None:
  97. attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
  98. else:
  99. attn_mask = attn_mask.logical_or(key_padding_mask)
  100. q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
  101. if (
  102. return_weights is False
  103. and memory_efficient_attention is not None
  104. and q.device.type == "cuda"
  105. ):
  106. # (-> b, t,. n, d)
  107. q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
  108. k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
  109. v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
  110. if attn_mask is not None:
  111. attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
  112. if attn_bias is None:
  113. attn_bias = torch.zeros_like(
  114. attn_mask, dtype=q.dtype, device=q.device
  115. )
  116. attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
  117. attn_bias = attn_bias.to(q.dtype)
  118. attn_output = memory_efficient_attention(
  119. q,
  120. k,
  121. v,
  122. attn_bias=attn_bias,
  123. scale=self.head_dim**-0.5,
  124. p=self.dropout.p,
  125. )
  126. attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
  127. returned_weights = None
  128. else:
  129. q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
  130. k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
  131. v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
  132. attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
  133. assert attn_weights.size() == (
  134. batch_size * self.nhead,
  135. q.size(1),
  136. k.size(1),
  137. )
  138. if attn_bias is not None:
  139. attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
  140. attn_weights = attn_weights + attn_bias
  141. if attn_mask is not None:
  142. attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
  143. attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
  144. returned_weights = attn_weights.view(
  145. batch_size, self.nhead, q.size(1), k.size(1)
  146. )
  147. attn_probs = self.dropout(attn_weights)
  148. attn_output = torch.bmm(attn_probs, v)
  149. attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
  150. attn_output = self.out_proj(attn_output)
  151. return attn_output, returned_weights
  152. class GluMLP(nn.Module):
  153. def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
  154. super().__init__()
  155. if intermediate_size is None:
  156. intermediate_size = hidden_size * (11 / 3)
  157. intermediate_size = round(intermediate_size / 8) * 8
  158. self.hidden_size = hidden_size
  159. self.intermediate_size = intermediate_size
  160. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  161. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  162. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  163. self.act_fn = activation()
  164. def forward(self, x):
  165. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  166. class RMSNorm(nn.Module):
  167. def __init__(self, hidden_size, eps=1e-6):
  168. """
  169. RMSNorm is equivalent to T5LayerNorm
  170. """
  171. super().__init__()
  172. self.weight = nn.Parameter(torch.ones(hidden_size))
  173. self.variance_epsilon = eps
  174. def forward(self, hidden_states):
  175. input_dtype = hidden_states.dtype
  176. hidden_states = hidden_states.to(torch.float32)
  177. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  178. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  179. return self.weight * hidden_states.to(input_dtype)
  180. class CrossAttentionLayer(nn.Module):
  181. def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
  182. super().__init__()
  183. self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
  184. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  185. self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
  186. self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
  187. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  188. def forward(
  189. self,
  190. tgt,
  191. memory,
  192. memory_key_padding_mask=None,
  193. ):
  194. residual = tgt
  195. tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
  196. x, attn_weights = self.attn(
  197. tgt,
  198. memory,
  199. memory,
  200. key_padding_mask=memory_key_padding_mask,
  201. return_weights=True,
  202. )
  203. residual = x + residual
  204. x = self.post_attention_layernorm(residual)
  205. x = self.mlp(x)
  206. x = x + residual
  207. return x, attn_weights
  208. class TransformerEncoderLayer(nn.Module):
  209. def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
  210. super().__init__()
  211. self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
  212. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  213. self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
  214. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  215. def forward(self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None):
  216. residual = x
  217. x = self.input_layernorm(x)
  218. x, _ = self.attn(
  219. x,
  220. x,
  221. x,
  222. attn_bias=attn_bias,
  223. key_padding_mask=key_padding_mask,
  224. attn_mask=tgt_mask,
  225. return_weights=False,
  226. )
  227. residual = x + residual
  228. x = self.post_attention_layernorm(residual)
  229. x = self.mlp(x)
  230. x = x + residual
  231. return x
  232. class FishSpeechTransformer(nn.Module):
  233. def __init__(
  234. self,
  235. vocab_size,
  236. codebook_size,
  237. num_codebooks,
  238. hidden_size=1024,
  239. intermediate_size=None,
  240. nhead=16,
  241. num_encoder_layers=12,
  242. num_decoder_layers=12,
  243. dropout=0.1,
  244. alignment_position=-2,
  245. max_position=8192,
  246. ):
  247. super().__init__()
  248. self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
  249. self.decoder_embeddings = nn.ModuleList(
  250. [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
  251. )
  252. self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
  253. self.codebook_size = codebook_size
  254. self.num_codebooks = num_codebooks
  255. self.encoder = nn.ModuleList(
  256. [
  257. TransformerEncoderLayer(
  258. hidden_size=hidden_size,
  259. intermediate_size=intermediate_size,
  260. nhead=nhead,
  261. dropout=dropout,
  262. )
  263. for _ in range(num_encoder_layers)
  264. ]
  265. )
  266. self.alignment = CrossAttentionLayer(
  267. hidden_size=hidden_size,
  268. intermediate_size=intermediate_size,
  269. dropout=dropout,
  270. )
  271. if alignment_position < 0:
  272. alignment_position = num_decoder_layers + alignment_position
  273. self.alignment_position = alignment_position
  274. assert 0 <= alignment_position < num_decoder_layers
  275. self.decoder = nn.ModuleList(
  276. [
  277. TransformerEncoderLayer(
  278. hidden_size=hidden_size,
  279. intermediate_size=intermediate_size,
  280. nhead=nhead,
  281. dropout=dropout,
  282. )
  283. for _ in range(num_decoder_layers)
  284. ]
  285. )
  286. self.alibi = AlibiPostionEmbedding(nhead, max_position)
  287. self.register_buffer(
  288. "causual_mask",
  289. torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
  290. )
  291. def forward(self, inputs, codes, input_mask=None, codes_mask=None):
  292. # x: (B, T)
  293. # y: (B, C, T)
  294. inputs = self.encoder_embedding(inputs)
  295. codes = rearrange(codes, "b c t -> c b t")
  296. codes = torch.stack(
  297. [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
  298. )
  299. codes = torch.mean(codes, dim=0) # (B, T)
  300. attn_bias = self.alibi(inputs)
  301. for layer in self.encoder:
  302. inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)
  303. attn_bias = self.alibi(codes)
  304. causual_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
  305. for idx, layer in enumerate(self.decoder):
  306. if idx == self.alignment_position:
  307. codes, _ = self.alignment(
  308. codes, inputs, memory_key_padding_mask=input_mask
  309. )
  310. codes = layer(
  311. codes,
  312. attn_bias=attn_bias,
  313. key_padding_mask=codes_mask,
  314. tgt_mask=causual_mask,
  315. )
  316. codes = self.decoder_head(codes)
  317. codes = rearrange(
  318. codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
  319. )
  320. return codes
  321. if __name__ == "__main__":
  322. mha = MultiheadAttention(512, 8, dropout=0)
  323. mha.eval()
  324. mha.cuda()
  325. q, k, v = torch.randn(3, 10, 16, 512)
  326. q, k, v = q.cuda(), k.cuda(), v.cuda()
  327. alibi = AlibiPostionEmbedding(8, 1024)
  328. mha.bfloat16()
  329. q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
  330. bias = alibi(q).bfloat16()
  331. # Causual mask
  332. attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
  333. o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
  334. print(o.size())
  335. print(w.size())
  336. o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
  337. print(o1.size())
  338. print(o[0], o1.float()[0])
  339. assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
  340. print("ok")
  341. cross = CrossAttentionLayer(512, 1024, dropout=0)
  342. cross.eval()
  343. cross.cuda()
  344. tgt = torch.randn(3, 10, 512).cuda()
  345. memory = torch.randn(3, 20, 512).cuda()
  346. o, w = cross(tgt, memory)
  347. print(o.size())
  348. print(w.size())
  349. ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
  350. ten.eval()
  351. ten.cuda()
  352. tgt = torch.randn(3, 10, 512).cuda()
  353. o = ten(tgt)
  354. print(o.size())
  355. trans = (
  356. FishSpeechTransformer(
  357. vocab_size=30000,
  358. codebook_size=120,
  359. num_codebooks=4,
  360. hidden_size=1024,
  361. intermediate_size=None,
  362. nhead=16,
  363. num_encoder_layers=12,
  364. num_decoder_layers=12,
  365. )
  366. .bfloat16()
  367. .cuda()
  368. )
  369. # Print n param
  370. print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
  371. inputs = torch.randint(0, 1000, (3, 16)).cuda()
  372. codes = torch.randint(0, 120, (3, 4, 128)).cuda()
  373. print(trans(inputs, codes).size())