فهرست منبع

Add 4 in 8 out support

Lengyue 2 سال پیش
والد
کامیت
f5a2df2d23
2 فایلهای تغییر یافته به همراه 10 افزوده شده و 5 حذف شده
  1. 3 2
      fish_speech/configs/text2semantic_pretrain.yaml
  2. 7 3
      fish_speech/models/text2semantic/llama.py

+ 3 - 2
fish_speech/configs/text2semantic_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_pretrain_400m_4_codebooks
+project: text2semantic_pretrain_400m_4_in_8_codebooks
 max_length: 2048
 
 # Lightning Trainer
@@ -63,7 +63,8 @@ model:
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 4  # single codebook
+      num_in_codebooks: 4 # input codebook size
+      num_codebooks: 8  # output codebook size
       codebook_size: 264 # codebook size 256 + 2 special tokens
       dropout: 0.1
       neft_alpha: 10

+ 7 - 3
fish_speech/models/text2semantic/llama.py

@@ -37,6 +37,7 @@ class ModelArgs:
     # Additional decoding heads
     codebook_size: int = 160
     num_codebooks: int = 4
+    num_in_codebooks: Optional[int] = None
     codebook_padding_idx: int = 0
 
     # Use flash attention
@@ -55,6 +56,8 @@ class ModelArgs:
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
+        if self.num_in_codebooks is None:
+            self.num_in_codebooks = self.num_codebooks
         self.head_dim = self.dim // self.n_head
 
 
@@ -91,7 +94,8 @@ class Transformer(nn.Module):
         self.config = config
 
         self.embeddings = nn.Embedding(
-            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
+            config.vocab_size + config.codebook_size * config.num_in_codebooks,
+            config.dim,
         )
         self.layers = nn.ModuleList(
             TransformerBlock(config) for _ in range(config.n_layer)
@@ -148,11 +152,11 @@ class Transformer(nn.Module):
 
     def embed(self, x: Tensor) -> Tensor:
         # Here we want to merge the embeddings of the codebooks
-        if self.config.num_codebooks == 0:
+        if self.config.num_in_codebooks == 0:
             return self.embeddings(x[:, 0])
 
         vocab_embeds = [self.embeddings(x[:, 0])]
-        for i in range(self.config.num_codebooks):
+        for i in range(self.config.num_in_codebooks):
             emb = self.embeddings(
                 x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
             )