modules.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import math
  2. import torch
  3. from einops import rearrange
  4. from torch import nn
  5. from torch.nn import functional as F
  6. try:
  7. from xformers.ops import memory_efficient_attention
  8. except ImportError as e:
  9. memory_efficient_attention = None
  10. # memory_efficient_attention = None
  11. class AlibiPostionEmbedding:
  12. def __init__(self, nheads, maxpos):
  13. context_position = torch.arange(maxpos)[:, None]
  14. memory_position = torch.arange(maxpos)[None, :]
  15. relative_position = memory_position - context_position
  16. relative_position = (
  17. torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
  18. )
  19. self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
  20. self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
  21. self.alibi = self.alibi.view(nheads, maxpos, maxpos)
  22. @staticmethod
  23. def get_slopes_power_of_2(n):
  24. start = 2 ** (-(2 ** -(math.log2(n) - 3)))
  25. ratio = start
  26. return [start * ratio**i for i in range(n)]
  27. def get_slopes(self, n):
  28. if math.log2(n).is_integer():
  29. return self.get_slopes_power_of_2(n)
  30. closest_power_of_2 = 2 ** math.floor(math.log2(n))
  31. return (
  32. self.get_slopes_power_of_2(closest_power_of_2)
  33. + self.get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
  34. )
  35. def __call__(self, x):
  36. # N, T, C
  37. return self.alibi[:, : x.size(1), : x.size(1)].to(x.device)
  38. class MultiheadAttention(nn.Module):
  39. def __init__(self, d_model, nhead, dropout=0.1):
  40. super().__init__()
  41. assert d_model % nhead == 0
  42. self.nhead = nhead
  43. self.d_model = d_model
  44. self.head_dim = d_model // nhead
  45. self.q_proj = nn.Linear(d_model, d_model)
  46. self.k_proj = nn.Linear(d_model, d_model)
  47. self.v_proj = nn.Linear(d_model, d_model)
  48. self.out_proj = nn.Linear(d_model, d_model)
  49. self.dropout = nn.Dropout(dropout)
  50. def forward(
  51. self,
  52. q,
  53. k,
  54. v,
  55. attn_mask=None,
  56. key_padding_mask=None,
  57. attn_bias=None,
  58. past_kv=None,
  59. return_weights=False,
  60. ):
  61. # (B, T, C)
  62. batch_size = q.size(0)
  63. q_length = q.size(1)
  64. k_length = k.size(1)
  65. if past_kv is not None:
  66. k, v = torch.cat([past_kv, k], 1), torch.cat([past_kv, v], 1)
  67. if attn_bias is not None:
  68. assert attn_bias.size() == (
  69. self.nhead,
  70. q_length,
  71. k_length,
  72. ), f"Should be {(self.nhead, q_length, k_length)}. Got {attn_bias.size()}"
  73. attn_bias = attn_bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
  74. if attn_mask is not None:
  75. assert attn_mask.size() == (
  76. q_length,
  77. k_length,
  78. ), f"Should be {(q_length, k_length)}. Got {attn_mask.size()}"
  79. assert attn_mask.dtype == torch.bool
  80. attn_mask = attn_mask.unsqueeze(0).expand(batch_size * self.nhead, -1, -1)
  81. if key_padding_mask is not None:
  82. assert key_padding_mask.size() == (
  83. batch_size,
  84. k_length,
  85. ), f"Should be {(batch_size, k_length)}. Got {key_padding_mask.size()}"
  86. assert key_padding_mask.dtype == torch.bool
  87. key_padding_mask = (
  88. key_padding_mask.unsqueeze(1)
  89. .unsqueeze(1)
  90. .expand(-1, self.nhead, -1, -1)
  91. )
  92. key_padding_mask = key_padding_mask.reshape(
  93. batch_size * self.nhead, 1, k_length
  94. )
  95. if attn_mask is None:
  96. attn_mask = key_padding_mask.expand(-1, q.size(1), -1)
  97. else:
  98. attn_mask = attn_mask.logical_or(key_padding_mask)
  99. q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
  100. if (
  101. return_weights is False
  102. and memory_efficient_attention is not None
  103. and q.device.type == "cuda"
  104. ):
  105. # (-> b, t,. n, d)
  106. q = rearrange(q, "b t (n d) -> b t n d", n=self.nhead)
  107. k = rearrange(k, "b t (n d) -> b t n d", n=self.nhead)
  108. v = rearrange(v, "b t (n d) -> b t n d", n=self.nhead)
  109. if attn_mask is not None:
  110. attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
  111. attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))
  112. attn_output = memory_efficient_attention(
  113. q,
  114. k,
  115. v,
  116. attn_bias=attn_bias,
  117. scale=self.head_dim**-0.5,
  118. p=self.dropout.p,
  119. )
  120. attn_output = rearrange(attn_output, "b t n d -> b t (n d)", n=self.nhead)
  121. returned_weights = None
  122. else:
  123. q = rearrange(q, "b t (n d) -> (b n) t d", n=self.nhead)
  124. k = rearrange(k, "b t (n d) -> (b n) t d", n=self.nhead)
  125. v = rearrange(v, "b t (n d) -> (b n) t d", n=self.nhead)
  126. attn_weights = torch.bmm(q, k.mT) * (self.head_dim**-0.5)
  127. assert attn_weights.size() == (
  128. batch_size * self.nhead,
  129. q.size(1),
  130. k.size(1),
  131. )
  132. if attn_bias is not None:
  133. attn_bias = rearrange(attn_bias, "b n q k -> (b n) q k")
  134. attn_weights = attn_weights + attn_bias
  135. if attn_mask is not None:
  136. attn_weights = attn_weights.masked_fill(attn_mask, float("-inf"))
  137. attn_weights = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)
  138. returned_weights = attn_weights.view(
  139. batch_size, self.nhead, q.size(1), k.size(1)
  140. )
  141. attn_probs = self.dropout(attn_weights)
  142. attn_output = torch.bmm(attn_probs, v)
  143. attn_output = rearrange(attn_output, "(b n) t d -> b t (n d)", n=self.nhead)
  144. attn_output = self.out_proj(attn_output)
  145. return attn_output, returned_weights
  146. class GluMLP(nn.Module):
  147. def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
  148. super().__init__()
  149. if intermediate_size is None:
  150. intermediate_size = hidden_size * (11 / 3)
  151. intermediate_size = round(intermediate_size / 8) * 8
  152. self.hidden_size = hidden_size
  153. self.intermediate_size = intermediate_size
  154. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  155. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  156. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  157. self.act_fn = activation()
  158. def forward(self, x):
  159. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  160. class RMSNorm(nn.Module):
  161. def __init__(self, hidden_size, eps=1e-6):
  162. """
  163. RMSNorm is equivalent to T5LayerNorm
  164. """
  165. super().__init__()
  166. self.weight = nn.Parameter(torch.ones(hidden_size))
  167. self.variance_epsilon = eps
  168. def forward(self, hidden_states):
  169. input_dtype = hidden_states.dtype
  170. hidden_states = hidden_states.to(torch.float32)
  171. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  172. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  173. return self.weight * hidden_states.to(input_dtype)
  174. class CrossAttentionLayer(nn.Module):
  175. def __init__(self, hidden_size=1024, intermediate_size=None, dropout=0.1):
  176. super().__init__()
  177. self.attn = MultiheadAttention(hidden_size, 1, dropout=dropout)
  178. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  179. self.input_layernorm_q = RMSNorm(hidden_size, eps=1e-6)
  180. self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
  181. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  182. def forward(self, tgt, memory, memory_key_padding_mask=None):
  183. residual = tgt
  184. tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
  185. x, attn_weights = self.attn(
  186. tgt,
  187. memory,
  188. memory,
  189. key_padding_mask=memory_key_padding_mask,
  190. return_weights=True,
  191. )
  192. residual = x + residual
  193. x = self.post_attention_layernorm(residual)
  194. x = self.mlp(x)
  195. x = x + residual
  196. return x, attn_weights
  197. class TransformerEncoderLayer(nn.Module):
  198. def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
  199. super().__init__()
  200. self.attn = MultiheadAttention(hidden_size, nhead, dropout=dropout)
  201. self.mlp = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  202. self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
  203. self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)
  204. def forward(self, x, attn_bias=None, key_padding_mask=None, tgt_mask=None):
  205. residual = x
  206. x = self.input_layernorm(x)
  207. x, _ = self.attn(
  208. x,
  209. x,
  210. x,
  211. attn_bias=attn_bias,
  212. key_padding_mask=key_padding_mask,
  213. attn_mask=tgt_mask,
  214. return_weights=False,
  215. )
  216. residual = x + residual
  217. x = self.post_attention_layernorm(residual)
  218. x = self.mlp(x)
  219. x = x + residual
  220. return x
  221. class FishSpeechTransformer(nn.Module):
  222. def __init__(
  223. self,
  224. vocab_size,
  225. codebook_size,
  226. num_codebooks,
  227. hidden_size=1024,
  228. intermediate_size=None,
  229. nhead=16,
  230. num_encoder_layers=12,
  231. num_decoder_layers=12,
  232. dropout=0.1,
  233. ):
  234. self.embedding = nn.Embedding(vocab_size, hidden_size)
  235. self.lm_head = nn.Linear(hidden_size, vocab_size * num_codebooks)
  236. if __name__ == "__main__":
  237. mha = MultiheadAttention(512, 8, dropout=0)
  238. mha.eval()
  239. mha.cuda()
  240. q, k, v = torch.randn(3, 10, 16, 512)
  241. q, k, v = q.cuda(), k.cuda(), v.cuda()
  242. alibi = AlibiPostionEmbedding(8, 1024)
  243. mha.bfloat16()
  244. q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()
  245. bias = alibi(q).bfloat16()
  246. # Causual mask
  247. attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
  248. o, w = mha(q, k, v, return_weights=True, attn_bias=bias, attn_mask=attn_mask)
  249. print(o.size())
  250. print(w.size())
  251. o1, w = mha(q, k, v, return_weights=False, attn_bias=bias, attn_mask=attn_mask)
  252. print(o1.size())
  253. print(o[0], o1.float()[0])
  254. assert torch.allclose(o.float(), o1.float(), atol=1e-2, rtol=1e-2)
  255. print("ok")
  256. cross = CrossAttentionLayer(512, 1024, dropout=0)
  257. cross.eval()
  258. cross.cuda()
  259. tgt = torch.randn(3, 10, 512).cuda()
  260. memory = torch.randn(3, 20, 512).cuda()
  261. o, w = cross(tgt, memory)
  262. print(o.size())
  263. print(w.size())
  264. ten = TransformerEncoderLayer(512, 1024, 8, dropout=0)
  265. ten.eval()
  266. ten.cuda()
  267. tgt = torch.randn(3, 10, 512).cuda()
  268. o = ten(tgt)
  269. print(o.size())