Quellcode durchsuchen

fix bug & implement generate

Lengyue vor 2 Jahren
Ursprung
Commit
9e1a9debfd
2 geänderte Dateien mit 54 neuen und 59 gelöschten Zeilen
  1. 51 45
      fish_speech/models/text2semantic/llama.py
  2. 3 14
      tools/llama/generate.py

+ 51 - 45
fish_speech/models/text2semantic/llama.py

@@ -148,7 +148,7 @@ class Transformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
-        for b in self.layers:
+        for b in self.slow_layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 max_seq_len,
@@ -157,6 +157,8 @@ class Transformer(nn.Module):
                 dtype=dtype,
             )
 
+        # TODO: add fast transformer kv cache
+
     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
         if self.config.num_codebooks == 0:
@@ -175,41 +177,6 @@ class Transformer(nn.Module):
 
         return x
 
-    def compute(
-        self,
-        x: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
-    ) -> TransformerForwardResult:
-        raise NotImplementedError
-
-        for layer in self.layers:
-            if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, freqs_cis, mask, input_pos, use_reentrant=True)
-            else:
-                x = layer(x, freqs_cis, mask, input_pos=input_pos)
-
-        x = self.norm(x)
-        logits = self.output(x)
-        token_logits = logits[:, :, : self.config.vocab_size]
-
-        if self.config.num_codebooks == 0:
-            return TransformerForwardResult(
-                token_logits=token_logits,
-                codebook_logits=None,
-            )
-
-        codebook_logits = logits[:, :, self.config.vocab_size :]
-        codebook_logits = rearrange(
-            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
-        )
-
-        return TransformerForwardResult(
-            token_logits=token_logits,
-            codebook_logits=codebook_logits,
-        )
-
     def forward(
         self, x: Tensor, key_padding_mask: Optional[Tensor] = None
     ) -> TransformerForwardResult:
@@ -248,15 +215,9 @@ class Transformer(nn.Module):
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
         fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
-        # There should be a bug here
-        # Say at t0, the given input is [/INST] for semantic token
-        # Then we want to predict <tok0>, <tok1>, ... (instead of <s> <s> <s>) given <feat>, <tok0>, <tok1>, ...
-        # Otherwise this becomes: decode tokens from same given tokens
-        # Ignore the last token, since the input should be <feat>, <tok0>, <tok1>, ...
         codebook_embeddings = self.fast_embeddings(codebooks[:, :-1])
 
-        x = torch.cat([x[:, None, 1:], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
+        x = torch.cat([x[:, None], codebook_embeddings], dim=1)  # (B, N + 1, S, D)
         b, s = x.size(0), x.size(2)
         x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
 
@@ -298,9 +259,54 @@ class Transformer(nn.Module):
         ]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[input_pos]
 
-        # TODO: support key padding mask for generation
+        for layer in self.slow_layers:
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # During prefill (more than one input position), keep only the last position for the logits
+        if x.size(1) > 1:
+            x = x[:, -1:]
+
+        # Normalize the slow transformer output and project it to token logits
+        slow_out = self.slow_norm(x)
+        token_logits = self.slow_output(slow_out)
+
+        # Fast transformer
+        fast_features = [x[:, None]]
+        fast_logits = []
+
+        for _ in range(self.config.num_codebooks):
+            x = torch.cat(fast_features, dim=1)  # (B, N + 1, S, D)
+            b, s = x.size(0), x.size(2)
+            x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
+
+            fast_seq_len = x.size(1)
+            fast_mask = self.causal_mask[
+                None, None, :fast_seq_len, :fast_seq_len
+            ]  # (B, N, Q, K)
+            fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
-        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
+            for layer in self.fast_layers:
+                x = layer(x, fast_freqs_cis, fast_mask)
+
+            # Normalize the last fast-sequence position and project it to codebook logits
+            fast_out = self.fast_norm(x[:, -1:])  # only take the last token
+            codebook_logits = self.fast_output(fast_out)
+            fast_logits.append(codebook_logits)
+
+            # Get the argmax
+            codebook_idx = codebook_logits.argmax(dim=-1)
+            codebook_embeddings = self.fast_embeddings(codebook_idx)
+            fast_features.append(codebook_embeddings.view(b, 1, s, -1))
+
+        codebook_logits = torch.stack(fast_logits, dim=1)
+        assert codebook_logits.shape[1] == self.config.num_codebooks
+
+        codebook_logits = rearrange(codebook_logits, "b c n d -> b n c d")
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
 
 
 class TransformerBlock(nn.Module):

+ 3 - 14
tools/llama/generate.py

@@ -111,13 +111,7 @@ def decode_one_token(
     if model.config.num_codebooks != 0:
         for i in range(model.config.num_codebooks):
             codebooks.append(
-                sample(
-                    logits.codebook_logits[:, :, i],
-                    previous_tokens=previous_tokens[i + 1]
-                    if previous_tokens is not None
-                    else None,
-                    **sampling_kwargs,
-                )[0]
+                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
             )
 
     return torch.stack(codebooks, dim=0)
@@ -139,11 +133,7 @@ def prefill(
     if model.config.num_codebooks != 0:
         for i in range(model.config.num_codebooks):
             codebooks.append(
-                sample(
-                    logits.codebook_logits[:, :, i],
-                    previous_tokens=None,
-                    **sampling_kwargs,
-                )[0]
+                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
             )
 
     return torch.stack(codebooks, dim=0)
@@ -340,8 +330,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
 
-    with torch.device("meta"):
-        model: Transformer = instantiate(cfg.model).model
+    model: Transformer = instantiate(cfg.model).model
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")