@@ -9,11 +9,12 @@ try:
     from xformers.ops import memory_efficient_attention
 except ImportError as e:
     memory_efficient_attention = None
-# memory_efficient_attention = None


-class AlibiPostionEmbedding:
+class AlibiPostionEmbedding(nn.Module):
     def __init__(self, nheads, maxpos):
+        super().__init__()
+
         context_position = torch.arange(maxpos)[:, None]
         memory_position = torch.arange(maxpos)[None, :]
         relative_position = memory_position - context_position
@@ -21,8 +22,10 @@ class AlibiPostionEmbedding:
             torch.abs(relative_position).unsqueeze(0).expand(nheads, -1, -1)
         )
         self.slopes = torch.Tensor(self.get_slopes(nheads)) * -1
-        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
-        self.alibi = self.alibi.view(nheads, maxpos, maxpos)
+        alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
+        alibi = alibi.view(nheads, maxpos, maxpos)
+
+        self.register_buffer("alibi", alibi)

     @staticmethod
     def get_slopes_power_of_2(n):
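The switch to `register_buffer` is the functional point of the two hunks above: once the precomputed ALiBi bias is a buffer (and the class an `nn.Module`), it follows the module through `.cuda()` / `.bfloat16()` and is saved in the `state_dict`, whereas a plain `self.alibi` attribute would be left behind in float32 on the CPU. A minimal sketch of that behaviour (`ToyBias` is a made-up stand-in, not part of this code):

```python
import torch
import torch.nn as nn


class ToyBias(nn.Module):
    """Stand-in module: a precomputed tensor stored as a buffer, not a Parameter."""

    def __init__(self, maxpos: int = 4):
        super().__init__()
        bias = torch.arange(maxpos, dtype=torch.float32)  # stand-in for the ALiBi bias
        self.register_buffer("bias", bias)  # tracked by .to()/.cuda()/.bfloat16() and state_dict
        self.plain = torch.arange(maxpos, dtype=torch.float32)  # ordinary attribute for comparison


m = ToyBias().bfloat16()
print(m.bias.dtype)              # torch.bfloat16 -- converted along with the module
print(m.plain.dtype)             # torch.float32  -- not tracked, stays as created
print("bias" in m.state_dict())  # True; "plain" is not saved
```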
@@ -128,8 +131,14 @@ class MultiheadAttention(nn.Module):

         if attn_mask is not None:
             attn_mask = rearrange(attn_mask, "(b n) q k -> b n q k", n=self.nhead)
+
+            if attn_bias is None:
+                attn_bias = torch.zeros_like(
+                    attn_mask, dtype=q.dtype, device=q.device
+                )
             attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))

+        attn_bias = attn_bias.to(q.dtype)
         attn_output = memory_efficient_attention(
             q,
             k,
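The added lines build the additive bias that xformers' `memory_efficient_attention` consumes: when no bias is supplied, a zero tensor is materialized in the query dtype, masked key positions are filled with `-inf`, and the result is cast to `q.dtype` before the call (note the final cast assumes `attn_bias` is no longer `None` at that point, i.e. that callers always pass either a bias or a mask). A torch-only sketch of the masking step, with shapes that are illustrative assumptions rather than taken from this module:

```python
import torch

# Assumed layout for illustration: (batch, heads, q_len, kv_len).
B, H, Q, K = 2, 4, 5, 5
q_dtype = torch.bfloat16

attn_mask = torch.zeros(B, H, Q, K, dtype=torch.bool)
attn_mask[..., -1] = True  # pretend the last key position is padding

attn_bias = torch.zeros_like(attn_mask, dtype=q_dtype)       # zero bias in the query dtype
attn_bias = attn_bias.masked_fill(attn_mask, float("-inf"))  # masked positions contribute -inf logits
print(attn_bias[0, 0, 0])  # zeros, with -inf in the last entry
```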
@@ -222,7 +231,12 @@ class CrossAttentionLayer(nn.Module):
         self.input_layernorm_kv = RMSNorm(hidden_size, eps=1e-6)
         self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)

-    def forward(self, tgt, memory, memory_key_padding_mask=None):
+    def forward(
+        self,
+        tgt,
+        memory,
+        memory_key_padding_mask=None,
+    ):
         residual = tgt
         tgt, memory = self.input_layernorm_q(tgt), self.input_layernorm_kv(memory)
         x, attn_weights = self.attn(
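The signature change here is cosmetic, but the surrounding context shows the pre-norm residual pattern this layer follows: keep the incoming `tgt` as the residual, RMS-normalize the query and memory streams separately, run cross-attention, and (presumably, since the tail of the method is outside this hunk) add the residual back. A schematic sketch of that pattern using stock PyTorch stand-ins (`nn.LayerNorm` and `nn.MultiheadAttention` instead of the module's RMSNorm and custom attention):

```python
import torch
import torch.nn as nn


class PreNormCrossAttention(nn.Module):
    """Schematic of the pre-norm cross-attention pattern; not the actual layer."""

    def __init__(self, hidden_size: int = 64, nhead: int = 4):
        super().__init__()
        self.norm_q = nn.LayerNorm(hidden_size)   # the real layer uses RMSNorm
        self.norm_kv = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)

    def forward(self, tgt, memory, memory_key_padding_mask=None):
        residual = tgt
        tgt, memory = self.norm_q(tgt), self.norm_kv(memory)  # normalize both streams first
        x, attn_weights = self.attn(
            tgt, memory, memory, key_padding_mask=memory_key_padding_mask
        )
        return residual + x, attn_weights  # residual connection around the attention block


layer = PreNormCrossAttention()
out, _ = layer(torch.randn(2, 10, 64), torch.randn(2, 7, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```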
@@ -283,9 +297,97 @@ class FishSpeechTransformer(nn.Module):
         num_encoder_layers=12,
         num_decoder_layers=12,
         dropout=0.1,
+        alignment_position=-2,
+        max_position=8192,
     ):
-        self.embedding = nn.Embedding(vocab_size, hidden_size)
-        self.lm_head = nn.Linear(hidden_size, vocab_size * num_codebooks)
+        super().__init__()
+
+        self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
+        self.decoder_embeddings = nn.ModuleList(
+            [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
+        )
+        self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
+        self.codebook_size = codebook_size
+        self.num_codebooks = num_codebooks
+
+        self.encoder = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    hidden_size=hidden_size,
+                    intermediate_size=intermediate_size,
+                    nhead=nhead,
+                    dropout=dropout,
+                )
+                for _ in range(num_encoder_layers)
+            ]
+        )
+
+        self.alignment = CrossAttentionLayer(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dropout=dropout,
+        )
+
+        if alignment_position < 0:
+            alignment_position = num_decoder_layers + alignment_position
+
+        self.alignment_position = alignment_position
+        assert 0 <= alignment_position < num_decoder_layers
+
+        self.decoder = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    hidden_size=hidden_size,
+                    intermediate_size=intermediate_size,
+                    nhead=nhead,
+                    dropout=dropout,
+                )
+                for _ in range(num_decoder_layers)
+            ]
+        )
+
+        self.alibi = AlibiPostionEmbedding(nhead, max_position)
+        self.register_buffer(
+            "causal_mask",
+            torch.triu(torch.ones(max_position, max_position), diagonal=1).bool(),
+        )
+
+    def forward(self, inputs, codes, input_mask=None, codes_mask=None):
+        # inputs: (B, T)
+        # codes: (B, C, T)
+        inputs = self.encoder_embedding(inputs)
+        codes = rearrange(codes, "b c t -> c b t")
+        codes = torch.stack(
+            [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
+        )
+        codes = torch.mean(codes, dim=0)  # (B, T, D)
+
+        attn_bias = self.alibi(inputs)
+        for layer in self.encoder:
+            inputs = layer(inputs, attn_bias=attn_bias, key_padding_mask=input_mask)
+
+        attn_bias = self.alibi(codes)
+        causal_mask = self.causal_mask[: codes.shape[1], : codes.shape[1]]
+
+        for idx, layer in enumerate(self.decoder):
+            if idx == self.alignment_position:
+                codes, _ = self.alignment(
+                    codes, inputs, memory_key_padding_mask=input_mask
+                )
+
+            codes = layer(
+                codes,
+                attn_bias=attn_bias,
+                key_padding_mask=codes_mask,
+                tgt_mask=causal_mask,
+            )
+
+        codes = self.decoder_head(codes)
+        codes = rearrange(
+            codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
+        )
+
+        return codes


 if __name__ == "__main__":
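Two details of the new decoder are worth spelling out. First, with the default `alignment_position=-2` and `num_decoder_layers=12`, the negative index is normalized to `12 - 2 = 10`, so the text/code cross-attention runs just before decoder layer index 10. Second, `decoder_head` emits `num_codebooks * codebook_size` logits per frame, and the final `rearrange` splits them into one logit block per codebook. A small standalone sketch of that reshape, using the same shapes as the smoke test below:

```python
import torch
from einops import rearrange

num_codebooks, codebook_size = 4, 120
# (B, T, C * D) -- what decoder_head produces for a batch of 3 and 128 code frames
logits = torch.randn(3, 128, num_codebooks * codebook_size)

per_codebook = rearrange(
    logits, "b t (c d) -> b c t d", c=num_codebooks, d=codebook_size
)
print(per_codebook.shape)  # torch.Size([3, 4, 128, 120])

# e.g. greedy token ids, one sequence per codebook
tokens = per_codebook.argmax(dim=-1)
print(tokens.shape)        # torch.Size([3, 4, 128])
```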
@@ -334,3 +436,23 @@ if __name__ == "__main__":
     tgt = torch.randn(3, 10, 512).cuda()
     o = ten(tgt)
     print(o.size())
+
+    trans = (
+        FishSpeechTransformer(
+            vocab_size=30000,
+            codebook_size=120,
+            num_codebooks=4,
+            hidden_size=1024,
+            intermediate_size=None,
+            nhead=16,
+            num_encoder_layers=12,
+            num_decoder_layers=12,
+        )
+        .bfloat16()
+        .cuda()
+    )
+    # Print the parameter count, scaled by 1024 * 1024 (roughly millions)
+    print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
+    inputs = torch.randint(0, 1000, (3, 16)).cuda()
+    codes = torch.randint(0, 120, (3, 4, 128)).cuda()
+    print(trans(inputs, codes).size())
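With those inputs, and assuming every encoder and decoder layer preserves sequence length, the final print should report `torch.Size([3, 4, 128, 120])`: batch 3, 4 codebooks, 128 code frames, and 120 logits per codebook, matching the `rearrange` sketch above.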