
Update dual ar generate & config

Lengyue, 2 years ago
362a1a7116

+ 1 - 1
fish_speech/configs/text2semantic_pretrain_large.yaml

@@ -2,7 +2,7 @@ defaults:
   - text2semantic_pretrain_small
   - _self_
 
-project: text2semantic_pretrain_large_4_in_8_codebooks
+project: text2semantic_pretrain_large_dual_ar
 
 # Model Configuration
 model:

+ 1 - 1
fish_speech/configs/text2semantic_pretrain_medium.yaml

@@ -2,7 +2,7 @@ defaults:
   - text2semantic_pretrain_small
   - _self_
 
-project: text2semantic_pretrain_medium_4_in_8_codebooks
+project: text2semantic_pretrain_medium_dual_ar
 
 # Model Configuration
 model:
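
The two config changes above only touch the project name; everything else is still pulled in from text2semantic_pretrain_small through the defaults list. A minimal sketch of how that composition resolves, assuming the usual Hydra compose API (the config path and call below are illustrative, not part of this commit):

from hydra import compose, initialize

# Resolve the large config; any key not overridden in
# text2semantic_pretrain_large.yaml falls back to text2semantic_pretrain_small.
with initialize(version_base=None, config_path="fish_speech/configs"):
    cfg = compose(config_name="text2semantic_pretrain_large")
    print(cfg.project)  # text2semantic_pretrain_large_dual_ar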

+ 30 - 34
fish_speech/models/text2semantic/llama.py

@@ -148,6 +148,7 @@ class Transformer(nn.Module):
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
 
+        # Slow transformer
         for b in self.slow_layers:
             b.attention.kv_cache = KVCache(
                 max_batch_size,
@@ -157,7 +158,16 @@ class Transformer(nn.Module):
                 dtype=dtype,
             )
 
-        # TODO: add fast transformer kv cache
+        # Fast transformer
+        # The max seq len here is the number of codebooks
+        for b in self.fast_layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                self.config.num_codebooks,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            )
 
     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
@@ -244,7 +254,9 @@ class Transformer(nn.Module):
             codebook_logits=codebook_logits,
         )
 
-    def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward_generate_slow(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> Tensor:
         ### TODO: fix this
         # x: (batch, num_codebooks + 1, 1)
 
@@ -270,43 +282,27 @@ class Transformer(nn.Module):
         slow_out = self.slow_norm(x)
         token_logits = self.slow_output(slow_out)
 
-        # Fast transformer
-        fast_features = [x[:, None]]
-        fast_logits = []
-
-        for _ in range(self.config.num_codebooks):
-            x = torch.cat(fast_features, dim=1)  # (B, N + 1, S, D)
-            b, s = x.size(0), x.size(2)
-            x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
-
-            fast_seq_len = x.size(1)
-            fast_mask = self.causal_mask[
-                None, None, :fast_seq_len, :fast_seq_len
-            ]  # (B, N, Q, K)
-            fast_freqs_cis = self.freqs_cis[:fast_seq_len]
-
-            for layer in self.fast_layers:
-                x = layer(x, fast_freqs_cis, fast_mask)
+        return x, token_logits
 
-            # unflatten the batch and num_codebooks
-            fast_out = self.fast_norm(x[:, -1:])  # only take the last token
-            codebook_logits = self.fast_output(fast_out)
-            fast_logits.append(codebook_logits)
+    def forward_generate_fast(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> Tensor:
+        # Fast transformer
+        x = x.view(1, 1, -1)
 
-            # Get the argmax
-            codebook_idx = codebook_logits.argmax(dim=-1)
-            codebook_embeddings = self.fast_embeddings(codebook_idx)
-            fast_features.append(codebook_embeddings.view(b, 1, s, -1))
+        fast_mask = self.causal_mask[
+            None, None, input_pos, : self.config.num_codebooks
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[input_pos]
 
-        codebook_logits = torch.stack(fast_logits, dim=1)
-        assert codebook_logits.shape[1] == self.config.num_codebooks
+        for layer in self.fast_layers:
+            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
 
-        codebook_logits = rearrange(codebook_logits, "b c n d -> b n c d")
+        # Normalize the fast hidden states and project to codebook logits
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
 
-        return TransformerForwardResult(
-            token_logits=token_logits,
-            codebook_logits=codebook_logits,
-        )
+        return codebook_logits
 
 
 class TransformerBlock(nn.Module):
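
Note on the llama.py change: the fast transformer now gets its own KV cache, but sized to num_codebooks positions instead of max_seq_len, because the fast stack is re-run from scratch for every frame and only ever attends over that frame's codebooks (generate.py zeroes these caches before each frame). A rough sketch of the resulting shapes, using a stand-in KVCache with the same constructor argument order as the calls in the diff; the concrete numbers are illustrative:

import torch

class KVCache(torch.nn.Module):
    # Stand-in: (max_batch_size, max_seq_len, n_heads, head_dim, dtype),
    # matching the argument order used in setup_caches above.
    def __init__(self, max_batch_size, max_seq_len, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

# Slow layers: one slot per position in the token sequence.
slow_cache = KVCache(1, 4096, 8, 64, torch.bfloat16)
# Fast layers: only num_codebooks slots (e.g. 8), reset to zero every frame.
fast_cache = KVCache(1, 8, 8, 64, torch.bfloat16)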

+ 52 - 24
tools/llama/generate.py

@@ -98,21 +98,34 @@ def decode_one_token(
 ) -> torch.Tensor:
     assert input_pos.shape[-1] == 1
 
-    logits = model.forward_generate(x, input_pos)
+    x, logits = model.forward_generate_slow(x, input_pos)
     codebooks = [
         sample(
-            logits.token_logits,
+            logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]
 
-    # Disable <s> and </s> tokens for codebooks
-    if model.config.num_codebooks != 0:
-        for i in range(model.config.num_codebooks):
-            codebooks.append(
-                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
-            )
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+        logits = model.forward_generate_fast(x, input_pos)
+        a = sample(
+            logits,
+            previous_tokens=(
+                previous_tokens[codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        x = model.fast_embeddings(a)
+        codebooks.append(a)
 
     return torch.stack(codebooks, dim=0)
 
@@ -121,20 +134,32 @@ def prefill(
     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
     # input_pos: [B, S]
-    logits = model.forward_generate(x, input_pos)
+    x, logits = model.forward_generate_slow(x, input_pos)
+    print("---", x.shape, logits.shape)
     codebooks = [
         sample(
-            logits.token_logits,
+            logits,
             previous_tokens=None,
             **sampling_kwargs,
         )[0]
     ]
 
-    if model.config.num_codebooks != 0:
-        for i in range(model.config.num_codebooks):
-            codebooks.append(
-                torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
-            )
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+        logits = model.forward_generate_fast(x, input_pos)
+        # print(x.shape, logits.shape)
+        a = sample(
+            logits,
+            previous_tokens=None,
+            **sampling_kwargs,
+        )[0]
+        x = model.fast_embeddings(a)
+        codebooks.append(a)
 
     return torch.stack(codebooks, dim=0)
 
@@ -317,7 +342,10 @@ def encode_tokens(
 
     # Since 1.0, we use <s:xxx> to replace <semantic>
     main_token_ids = torch.tensor(
-        [[tokenizer.pad_token_id] * data.size(1)], dtype=torch.int, device=device
+        # TODO: replace this
+        [[tokenizer.pad_token_id] * data.size(1)],
+        dtype=torch.int,
+        device=device,
     )
 
     data = torch.cat((main_token_ids, data), dim=0)
@@ -397,7 +425,7 @@ def split_text(text, min_length):
 @click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-k", type=int, default=None)
 @click.option("--top-p", type=float, default=0.5)
-@click.option("--repetition-penalty", type=float, default=1.5)
+@click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
 @click.option(
     "--checkpoint-path",
@@ -544,14 +572,14 @@ def main(
             # Put the generated tokens
             codes = y[1:, prompt_length:-1].clone()
 
-            if getattr(cfg, "use_delay_pattern", True):
-                new_codes = []
-                for j, code in enumerate(codes):
-                    new_codes.append(
-                        code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
-                    )
+            # if getattr(cfg, "use_delay_pattern", True):
+            #     new_codes = []
+            #     for j, code in enumerate(codes):
+            #         new_codes.append(
+            #             code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+            #         )
 
-                codes = torch.stack(new_codes, dim=0)
+            #     codes = torch.stack(new_codes, dim=0)
 
             codes = codes - 2
             if not (codes >= 0).all():
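
In the rewritten decode loop, the semantic token is still sampled with previous_tokens=None (no repetition penalty), while every acoustic codebook now goes through sample() with previous_tokens[codebook_idx + 1], i.e. row 0 of the history holds the semantic tokens and rows 1..num_codebooks hold the codebook streams; previously the codebooks were plain argmax picks with no penalty at all. The sketch below shows the standard CTRL-style penalty that a sample() helper like this typically applies and that the lowered --repetition-penalty default (1.5 -> 1.2) tunes; the function name and shapes are assumptions, not the repo's exact implementation:

import torch

def apply_repetition_penalty(logits, previous_tokens, penalty=1.2):
    # logits: (1, vocab_size); previous_tokens: (1, history_len) LongTensor.
    # Tokens that already appeared get pushed down: divided if positive,
    # multiplied if negative, so immediate repeats become less likely.
    score = torch.gather(logits, dim=-1, index=previous_tokens)
    score = torch.where(score < 0, score * penalty, score / penalty)
    return logits.scatter(dim=-1, index=previous_tokens, src=score)

The milder 1.2 default presumably compensates for the penalty now being applied to every codebook stream, whereas the old path never penalized them.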