Browse Source

Rollback decoding

Lengyue 2 years ago
parent
commit
87ca8cf62e
2 changed files with 2 additions and 22 deletions
  1. 2 20
      fish_speech/models/text2semantic/llama.py
  2. 0 2
      tools/llama/generate.py

+ 2 - 20
fish_speech/models/text2semantic/llama.py

@@ -225,11 +225,9 @@ class Transformer(nn.Module):
 
         # Drop the last token and rotate left
         codebooks = codebooks[:, :-1, 1:]
-        codebooks = F.pad(
-            codebooks, (0, 1, 1, 0), value=self.config.codebook_padding_idx
-        )
+        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
         codebook_embeddings = self.fast_embeddings(codebooks)
-        x = codebook_embeddings + x[:, None]  # (B, N + 1, S, D)
+        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
 
@@ -268,22 +266,6 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
-    def forward_fast(self, x: Tensor) -> Tensor:
-        # Fast transformer
-        fast_seq_len = x.shape[1]
-        fast_mask = self.causal_mask[
-            None, None, :fast_seq_len, :fast_seq_len
-        ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
-        for layer in self.fast_layers:
-            x = layer(x, fast_freqs_cis, fast_mask)
-
-        fast_out = self.fast_norm(x)
-        codebook_logits = self.fast_output(fast_out)
-
-        return codebook_logits
-
     def forward_generate_slow(
         self, x: Tensor, input_pos: Optional[Tensor] = None
     ) -> Tensor:

+ 0 - 2
tools/llama/generate.py

@@ -112,7 +112,6 @@ def decode_one_token(
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
-    buffer = [x.view(1, 1, -1)]
     for codebook_idx in range(model.config.num_codebooks):
         input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
         logits = model.forward_generate_fast(x, input_pos)
@@ -165,7 +164,6 @@ def prefill(
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)
 
-    buffer = [x.view(1, 1, -1)]
     for codebook_idx in range(model.config.num_codebooks):
         input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
         logits = model.forward_generate_fast(x, input_pos)