瀏覽代碼

support basic TTS inference

Lengyue 1 年之前
父節點
當前提交
da29c49fef
共有 2 個文件被更改,包括 92 次插入,16 次刪除
  1. 72 10
      fish_speech/models/text2semantic/llama.py
  2. 20 6
      tools/llama/generate.py

+ 72 - 10
fish_speech/models/text2semantic/llama.py

@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import math
 from collections import OrderedDict
@@ -57,6 +58,10 @@ class BaseModelArgs:
     # Initialize the model
     initializer_range: float = 0.02
 
+    # Dummy vars
+    is_reward_model: bool = False
+    share_codebook_embeddings: bool = True
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -100,6 +105,28 @@ class NaiveModelArgs(BaseModelArgs):
 class DualARModelArgs(BaseModelArgs):
     model_type: str = "dual_ar"
     n_fast_layer: int = 4
+    fast_dim: int | None = None
+    fast_n_head: int | None = None
+    fast_n_local_heads: int | None = None
+    fast_head_dim: int | None = None
+    fast_intermediate_size: int | None = None
+    fast_attention_qkv_bias: bool | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        self.fast_dim = self.fast_dim or self.dim
+        self.fast_n_head = self.fast_n_head or self.n_head
+        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
+        self.fast_head_dim = self.fast_head_dim or self.head_dim
+        self.fast_intermediate_size = (
+            self.fast_intermediate_size or self.intermediate_size
+        )
+        self.fast_attention_qkv_bias = (
+            self.fast_attention_qkv_bias
+            if self.fast_attention_qkv_bias is not None
+            else self.attention_qkv_bias
+        )
 
 
 class KVCache(nn.Module):
@@ -474,20 +501,46 @@ class DualARTransformer(BaseTransformer):
     def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)
 
+        # Project to fast dim if needed
+        if config.fast_dim is not None and config.fast_dim != config.dim:
+            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
+        else:
+            self.fast_project_in = nn.Identity()
+
         # Fast transformer
-        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
 
         # The equivalent bs is so large that sdpa doesn't work
+        override_config = dataclasses.replace(
+            config,
+            dim=config.fast_dim,
+            n_head=config.fast_n_head,
+            n_local_heads=config.fast_n_local_heads,
+            head_dim=config.fast_head_dim,
+            intermediate_size=config.fast_intermediate_size,
+            attention_qkv_bias=config.fast_attention_qkv_bias,
+        )
+
         self.fast_layers = nn.ModuleList(
-            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+            TransformerBlock(override_config, use_sdpa=False)
+            for _ in range(config.n_fast_layer)
         )
-        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
         self.fast_output = nn.Linear(
-            config.dim,
+            config.fast_dim,
             config.codebook_size,
             bias=False,
         )
 
+        self.register_buffer(
+            "fast_freqs_cis",
+            precompute_freqs_cis(
+                config.num_codebooks,
+                config.fast_dim // config.fast_n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
         self.apply(self._init_weights)
 
     def setup_caches(
@@ -495,7 +548,7 @@ class DualARTransformer(BaseTransformer):
     ):
         super().setup_caches(max_batch_size, max_seq_len, dtype)
 
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.fast_dim // self.config.fast_n_head
 
         # Fast transformer
         # The max seq len here is the number of codebooks
@@ -503,7 +556,7 @@ class DualARTransformer(BaseTransformer):
             b.attention.kv_cache = KVCache(
                 max_batch_size,
                 self.config.num_codebooks,
-                self.config.n_local_heads,
+                self.config.fast_n_local_heads,
                 head_dim,
                 dtype=dtype,
             )
@@ -516,13 +569,13 @@ class DualARTransformer(BaseTransformer):
         parent_result = super().forward(inp, key_padding_mask)
         token_logits = parent_result.logits
         x = parent_result.hidden_states
+        x = self.fast_project_in(x)
 
         # Fast transformer
         fast_seq_len = self.config.num_codebooks
         fast_mask = self.causal_mask[
             None, None, :fast_seq_len, :fast_seq_len
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
 
         # Drop the last token and rotate left
         codebooks = inp[:, 1:-1, 1:]
@@ -545,9 +598,11 @@ class DualARTransformer(BaseTransformer):
 
         for layer in self.fast_layers:
             if self.config.use_gradient_checkpointing and self.training:
-                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+                x = checkpoint(
+                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
+                )
             else:
-                x = layer(x, fast_freqs_cis, fast_mask)
+                x = layer(x, self.fast_freqs_cis, fast_mask)
 
         # unflatten the batch and num_codebooks
         fast_out = self.fast_norm(x)
@@ -587,7 +642,7 @@ class DualARTransformer(BaseTransformer):
         fast_mask = self.causal_mask[
             None, None, input_pos, : self.config.num_codebooks
         ]  # (B, N, Q, K)
-        fast_freqs_cis = self.freqs_cis[input_pos]
+        fast_freqs_cis = self.fast_freqs_cis[input_pos]
 
         for layer in self.fast_layers:
             x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
@@ -598,6 +653,13 @@ class DualARTransformer(BaseTransformer):
 
         return codebook_logits
 
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        x = super().forward_generate(x, input_pos)
+        x.hidden_states = self.fast_project_in(x.hidden_states)
+        return x
+
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:

+ 20 - 6
tools/llama/generate.py

@@ -91,14 +91,15 @@ def decode_one_token_ar(
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
+    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
 
     sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
 
     codebooks = [
         sample(
@@ -130,7 +131,12 @@ def decode_one_token_ar(
         x = model.fast_embeddings(a)
         codebooks.append(a)
 
-    return torch.stack(codebooks, dim=0)
+    codebooks = torch.stack(codebooks, dim=0)
+    codebooks[1:, :] = torch.masked_fill(
+        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
 
 
 def decode_one_token_naive(
@@ -176,6 +182,7 @@ def decode_n_tokens(
     num_new_tokens: int,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
+    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -204,6 +211,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
+                semantic_id=semantic_id,
                 **sampling_kwargs,
             )
 
@@ -236,6 +244,7 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -266,7 +275,11 @@ def generate(
     )
 
     next_token = prefill_decode(
-        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
 
@@ -278,6 +291,7 @@ def generate(
         max_new_tokens - 1,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
+        semantic_id=semantic_id,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -295,7 +309,7 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
+    string = f"<|im_start|>user\nSpeak: {string}<|im_end|><|im_start|>assistant\n"
 
     new_tokens = tokenizer.encode(
         string,