Lengyue 2 years ago
Parent
Commit
0a1986bb14

+ 2 - 1
fish_speech/configs/text2semantic_pretrain_large.yaml

@@ -8,6 +8,7 @@ project: text2semantic_pretrain_large_dual_ar
 model:
   model:
     config:
-      n_layer: 36
+      n_slow_layer: 36
+      n_fast_layer: 8
       n_head: 20
       dim: 1280

+ 2 - 1
fish_speech/configs/text2semantic_pretrain_medium.yaml

@@ -8,6 +8,7 @@ project: text2semantic_pretrain_medium_dual_ar
 model:
   model:
     config:
-      n_layer: 24
+      n_slow_layer: 24
+      n_fast_layer: 6
       n_head: 16
       dim: 1024
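
Both configs make the same edit: the single n_layer depth is split into a deep slow stack and a much shallower fast stack for the dual-AR model (the project names already carry the _dual_ar suffix). A minimal sketch of what these fields imply, assuming generic transformer blocks rather than the actual fish-speech modules:

import torch.nn as nn

# Illustrative sketch only: the class and layer type are assumptions based on
# the field names above, not the fish-speech code. The deep "slow" stack runs
# over the merged sequence; the shallow "fast" stack runs per codebook step.
class DualARDepthSketch(nn.Module):
    def __init__(self, n_slow_layer=36, n_fast_layer=8, dim=1280, n_head=20):
        super().__init__()
        block = lambda: nn.TransformerEncoderLayer(dim, n_head, batch_first=True)
        self.slow_layers = nn.ModuleList(block() for _ in range(n_slow_layer))
        self.fast_layers = nn.ModuleList(block() for _ in range(n_fast_layer))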

+ 2 - 7
fish_speech/models/text2semantic/lit_module.py

@@ -165,7 +165,7 @@ class TextToSemantic(L.LightningModule):
         # Do positive and negative samples in the same batch to speed up training
         labels = batch["labels"]
         outputs = self.model(
-            x=batch["inputs"],
+            inp=batch["inputs"],
             key_padding_mask=batch["attention_masks"],
         )
         token_logits = outputs.token_logits
@@ -186,12 +186,7 @@ class TextToSemantic(L.LightningModule):

         # If we have a codebook, add the loss
         if self.model.config.num_codebooks != 0:
-            # We want to shift the labels by one to the right
-            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks, :-1]
-            codebook_labels = torch.nn.functional.pad(
-                codebook_labels, (1, 0), value=-100
-            ).mT
-
+            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
             semantic_loss = F.cross_entropy(
                 codebook_logits.reshape(-1, codebook_logits.size(-1)),
                 codebook_labels.reshape(-1),
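
The right-shift that used to happen here now lives inside the model's forward (see the llama.py hunk below), so the training step can slice and transpose the labels directly; the removed code padded with -100 because that is F.cross_entropy's default ignore_index. A toy shape check, assuming the fast-transformer logits are laid out as (batch, seq, codebooks, vocab), inferred from the .mT transpose rather than confirmed by the source:

import torch
import torch.nn.functional as F

# Shape assumptions: logits (B, S, N, V), labels (B, 1 + N, S) with row 0
# holding the text tokens. `.mT` swaps the last two label dims so logits and
# labels flatten in the same order.
B, N, S, V = 2, 4, 8, 32
codebook_logits = torch.randn(B, S, N, V)
labels = torch.randint(0, V, (B, 1 + N, S))

codebook_labels = labels[:, 1 : 1 + N].mT  # (B, S, N)
semantic_loss = F.cross_entropy(
    codebook_logits.reshape(-1, V),
    codebook_labels.reshape(-1),           # ignore_index defaults to -100
)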

+ 9 - 5
fish_speech/models/text2semantic/llama.py

@@ -188,17 +188,17 @@ class Transformer(nn.Module):
         return x

     def forward(
-        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
+        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
     ) -> TransformerForwardResult:
         # inp: (batch, num_codebooks + 1, seq_len)
-        seq_len = x.size(2)
+        seq_len = inp.size(2)

         # For the codebooks, decoding is actually shifted by 1,
         # which is the labels section
-        codebooks = x[:, 1:]
+        codebooks = inp[:, 1:]

         # Here we want to merge the embeddings of the codebooks
-        x = self.embed(x)
+        x = self.embed(inp)

         mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[:seq_len]
@@ -225,7 +225,11 @@ class Transformer(nn.Module):
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-        codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
+
+        # Drop the last codebook row and shift the tokens left by one, padding the tail
+        codebooks = codebooks[:, :-1, 1:]
+        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+        codebook_embeddings = self.fast_embeddings(codebooks)

         x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
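
A quick toy check of the new input shift: drop the last codebook row, slide the remaining tokens one step left, and pad the freed tail slot. Shifting the inputs left inside forward is what lets the training step above stop right-shifting its labels. PAD stands in for config.codebook_padding_idx; its value is assumed here.

import torch
import torch.nn.functional as F

PAD = 0  # assumption: stands in for config.codebook_padding_idx
codebooks = torch.tensor([[[1, 2, 3, 4],
                           [5, 6, 7, 8]]])  # (B=1, N=2 codebooks, S=4)

# Drop the last codebook row, shift left by one, pad the tail.
shifted = F.pad(codebooks[:, :-1, 1:], (0, 1), value=PAD)
print(shifted)  # tensor([[[2, 3, 4, 0]]])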