
Optimize training speed by skipping unused decode

Lengyue, 2 years ago
parent commit 6e67ebda73

+ 1 - 1
fish_speech/datasets/text.py

@@ -593,7 +593,7 @@ class TextDataCollator:
                     (0, max_tokens_length - tokens_length),
                     value=self.tokenizer.eos_token_id,
                 )
-                _tokens[1:, tokens_length:] = CODEBOOK_EOS_TOKEN_ID
+                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                 _labels = F.pad(
                     _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                 )
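
The collator now fills the padded tail of the codebook rows with the pad token rather than the EOS token, so downstream code can tell "no codebook content at this position" apart from a real end-of-sequence token; this is what the skip logic in llama.py below keys on. A minimal sketch of the padding behaviour, using placeholder constants and shapes (CODEBOOK_PAD_TOKEN_ID and eos_token_id here are made-up values, not the repo's):

    import torch
    import torch.nn.functional as F

    CODEBOOK_PAD_TOKEN_ID = 0  # placeholder, not the repo's actual constant
    eos_token_id = 2           # placeholder tokenizer EOS id

    num_codebooks, tokens_length, max_tokens_length = 4, 6, 10
    _tokens = torch.randint(3, 100, (num_codebooks + 1, tokens_length))

    # Right-pad every row to max_tokens_length with EOS, as before...
    _tokens = F.pad(_tokens, (0, max_tokens_length - tokens_length), value=eos_token_id)
    # ...but overwrite the padded tail of the codebook rows (rows 1..N) with the
    # pad token, so a fully padded codebook column marks a pure text position.
    _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID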

+ 32 - 6
fish_speech/models/text2semantic/llama.py

@@ -192,9 +192,6 @@ class Transformer(nn.Module):
     ) -> TransformerForwardResult:
         # x: (batch, num_codebooks + 1, seq_len)
         seq_len = inp.size(2)
-
-        # For codebook, the decoding is actually shifted by 1
-        # Which  is the labels section
         codebooks = inp[:, 1:]
 
         # Here we want to merge the embeddings of the codebooks
@@ -228,13 +225,20 @@ class Transformer(nn.Module):
 
         # Drop the last token and rotate left
         codebooks = codebooks[:, :-1, 1:]
-        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
+        codebooks = F.pad(
+            codebooks, (0, 1, 1, 0), value=self.config.codebook_padding_idx
+        )
         codebook_embeddings = self.fast_embeddings(codebooks)
-
-        x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
+        x = codebook_embeddings + x[:, None]  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
 
+        # Remove padded part
+        codebooks = rearrange(codebooks, "b n s -> (b s) n")
+        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
+        x_bs, x_len = x.size(0), x.size(1)
+        x = x[~codebook_mask]
+
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
                 x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
@@ -244,6 +248,12 @@ class Transformer(nn.Module):
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
         codebook_logits = self.fast_output(fast_out)
+
+        # Re-pad the codebook_logits
+        buffer = torch.zeros(x_bs, x_len, codebook_logits.size(-1), device=x.device)
+        buffer[~codebook_mask] = codebook_logits
+        codebook_logits = buffer
+
         assert codebook_logits.shape[1] == self.config.num_codebooks
         codebook_logits = rearrange(
             codebook_logits,
@@ -258,6 +268,22 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
+    def forward_fast(self, x: Tensor) -> Tensor:
+        # Fast transformer
+        fast_seq_len = x.shape[1]
+        fast_mask = self.causal_mask[
+            None, None, :fast_seq_len, :fast_seq_len
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+        for layer in self.fast_layers:
+            x = layer(x, fast_freqs_cis, fast_mask)
+
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
+
+        return codebook_logits
+
     def forward_generate_slow(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:
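
The core of the change is in Transformer.forward above: the codebook embeddings are summed with the broadcast text embedding instead of concatenated, the (batch, seq) positions are flattened, and every position whose codebook column consists entirely of padding (a pure text position that contributes nothing to the codebook loss) is dropped before the per-position fast layers run. The resulting logits are then scattered back into a zero buffer so the output keeps its original layout. A self-contained sketch of that compact-then-scatter pattern with toy shapes (the linear layers stand in for the fast transformer stack; all names and sizes here are assumptions, not the repo's code):

    import torch
    import torch.nn as nn
    from einops import rearrange

    codebook_padding_idx = 0
    b, n, s, d, vocab = 2, 4, 6, 8, 16           # batch, codebooks, seq_len, dim, vocab
    codebooks = torch.randint(0, 3, (b, n, s))   # toy codebook ids (0 == padding)
    x = torch.randn(b, n, s, d)                  # codebook embeddings + broadcast text embedding

    x = rearrange(x, "b n s d -> (b s) n d")
    codebooks = rearrange(codebooks, "b n s -> (b s) n")

    # A position is "pure text" when every codebook holds the padding id; the fast
    # decoder has nothing to predict there, so drop those rows before the heavy layers.
    codebook_mask = (codebooks == codebook_padding_idx).all(dim=-1)
    x_bs, x_len = x.size(0), x.size(1)
    x = x[~codebook_mask]

    fast_layers = nn.Linear(d, d)                # stand-in for the stack of fast layers
    fast_output = nn.Linear(d, vocab)
    codebook_logits = fast_output(fast_layers(x))

    # Scatter the logits back so downstream code still sees the full (b * s, n, vocab) shape.
    buffer = torch.zeros(x_bs, x_len, vocab)
    buffer[~codebook_mask] = codebook_logits
    codebook_logits = buffer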

+ 29 - 2
tools/llama/generate.py

@@ -112,9 +112,11 @@ def decode_one_token(
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
+    buffer = [x.view(1, 1, -1)]
     for codebook_idx in range(model.config.num_codebooks):
         input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
         logits = model.forward_generate_fast(x, input_pos)
+        # print(x.shape, logits.shape)
         a = sample(
             logits,
             previous_tokens=(
@@ -126,6 +128,20 @@ def decode_one_token(
         )[0]
         x = model.fast_embeddings(a)
         codebooks.append(a)
+        # x = torch.cat(buffer, dim=1)
+        # logits = model.forward_fast(x)[:, -1:, :]
+        # a = sample(
+        #     logits,
+        #     previous_tokens=(
+        #         previous_tokens[codebook_idx + 1]
+        #         if previous_tokens is not None
+        #         else None
+        #     ),
+        #     **sampling_kwargs,
+        # )[0]
+        # x = model.fast_embeddings(a)
+        # codebooks.append(a)
+        # buffer.append(x.view(1, 1, -1))
 
     return torch.stack(codebooks, dim=0)
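
The live path in decode_one_token decodes the extra codebooks autoregressively: each iteration runs the fast decoder at KV-cache position codebook_idx, samples one token, and feeds that token's embedding into the next iteration. A toy sketch of that loop, with a single linear layer standing in for forward_generate_fast and greedy argmax standing in for sample(...) (all shapes and module names here are assumptions):

    import torch
    import torch.nn as nn

    num_codebooks, dim, vocab = 4, 8, 16
    fast_embeddings = nn.Embedding(vocab, dim)
    fast_output = nn.Linear(dim, vocab)          # stand-in for the fast transformer + head

    x = torch.randn(1, dim)                      # hidden state from the slow transformer
    codebooks = []
    for codebook_idx in range(num_codebooks):
        # The real model attends over a per-codebook KV cache indexed by codebook_idx;
        # here one projection stands in for forward_generate_fast.
        logits = fast_output(x)
        a = logits.argmax(dim=-1)                # greedy stand-in for sample(...)
        codebooks.append(a)
        x = fast_embeddings(a)                   # the next codebook conditions on this token

    result = torch.stack(codebooks, dim=0)       # (num_codebooks, 1), matching the return above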
 
@@ -135,7 +151,7 @@ def prefill(
 ) -> torch.Tensor:
     # input_pos: [B, S]
     x, logits = model.forward_generate_slow(x, input_pos)
-    print("---", x.shape, logits.shape)
+
     codebooks = [
         sample(
             logits,
@@ -149,6 +165,7 @@ def prefill(
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
+    buffer = [x.view(1, 1, -1)]
     for codebook_idx in range(model.config.num_codebooks):
         input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
         logits = model.forward_generate_fast(x, input_pos)
@@ -160,6 +177,15 @@ def prefill(
         )[0]
         x = model.fast_embeddings(a)
         codebooks.append(a)
+        # x = torch.cat(buffer, dim=1)
+        # logits = model.forward_fast(x)[:, -1:, :]
+        # a = sample(
+        #     logits,
+        #     **sampling_kwargs,
+        # )[0]
+        # x = model.fast_embeddings(a)
+        # codebooks.append(a)
+        # buffer.append(x.view(1, 1, -1))
 
     return torch.stack(codebooks, dim=0)
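
A hunk further down stacks @torch.inference_mode() on top of the existing @torch.no_grad() on generate(). As a minimal sketch of what that changes: inference_mode disables autograd tracking like no_grad, and additionally skips view and version-counter bookkeeping, which is slightly cheaper during generation (tensors created inside it cannot later take part in autograd):

    import torch

    @torch.no_grad()
    @torch.inference_mode()  # inference_mode alone would also suffice
    def step(model, x):
        # No autograd graph is recorded and version-counter bookkeeping is skipped.
        return model(x)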
 
@@ -211,6 +237,7 @@ def decode_n_tokens(
 
 
 @torch.no_grad()
+@torch.inference_mode()
 def generate(
     *,
     model: Transformer,
@@ -424,7 +451,7 @@ def split_text(text, min_length):
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-k", type=int, default=None)
-@click.option("--top-p", type=float, default=0.5)
+@click.option("--top-p", type=float, default=0.9)
 @click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
 @click.option(
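
The --top-p default rises from 0.5 to 0.9, so nucleus sampling keeps a much larger share of the probability mass per step. The repo's own sample() is not shown in this diff; as a hedged sketch, a generic top-p filter works like this (keep the smallest set of tokens whose cumulative probability exceeds top_p, mask the rest):

    import torch

    def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Tokens outside the nucleus; the exclusive cumsum keeps at least the top-1 token.
        to_remove = cumulative - sorted_probs > top_p
        mask = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, sorted_idx, to_remove)
        return logits.masked_fill(mask, float("-inf"))

    logits = torch.randn(1, 32)
    token = torch.multinomial(torch.softmax(top_p_filter(logits, 0.9), dim=-1), 1)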