
Update configs

Lengyue committed 2 years ago
parent
commit 0a0778191c

+ 6 - 7
fish_speech/configs/text2semantic_finetune.yaml

@@ -48,20 +48,19 @@ model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
 
   model:
-    # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.llama.Transformer
+    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
     config:
-      _target_: fish_speech.models.text2semantic.llama.ModelArgs
-      max_seq_len: 4096
+      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+      max_seq_len: ${max_length}
       vocab_size: 36408
       n_layer: 24
+      n_fast_layer: 6
       n_head: 16
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 4  # single codebook
-      codebook_size: 168 # codebook size 160 + 2 special tokens
-      dropout: 0.1 # For small dataset, dropout helps to prevent overfitting
+      num_codebooks: 8  # input/output codebook size
+      codebook_size: 264 # codebook size 256 + 2 special tokens
 
   optimizer:
     _target_: torch.optim.AdamW
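
For reference, this is roughly how the model block of fish_speech/configs/text2semantic_finetune.yaml reads once the hunk is applied (reassembled from the context and added lines above; the inline comments on max_seq_len, n_layer, and n_fast_layer are my reading of the Hydra/DualAR naming, not part of the file):

```yaml
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
    config:
      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
      max_seq_len: ${max_length}  # OmegaConf interpolation of the top-level max_length key
      vocab_size: 36408
      n_layer: 24       # depth of the slow (token-level) transformer
      n_fast_layer: 6   # depth of the fast (codebook-level) transformer, new in this commit
      n_head: 16
      dim: 1024
      rope_base: 10000
      norm_eps: 1e-5
      num_codebooks: 8  # input/output codebook size
      codebook_size: 264 # codebook size 256 + 2 special tokens

  optimizer:
    _target_: torch.optim.AdamW
```

The practical effect is a switch from the single-stage Transformer to the dual-autoregressive variant, with max_seq_len now following the configured max_length instead of a hard-coded 4096.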

+ 6 - 6
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -49,19 +49,19 @@ model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
 
   model:
-    _target_: fish_speech.models.text2semantic.llama.Transformer
+    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
     config:
-      _target_: fish_speech.models.text2semantic.llama.ModelArgs
-      max_seq_len: 4096
+      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+      max_seq_len: ${max_length}
       vocab_size: 36408
       n_layer: 24
+      n_fast_layer: 6
       n_head: 16
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 4  # single codebook
-      codebook_size: 168 # codebook size 160 + 2 special tokens
-      dropout: 0.1 # For small dataset, dropout helps to prevent overfitting
+      num_codebooks: 8  # input/output codebook size
+      codebook_size: 264 # codebook size 256 + 2 special tokens
 
   save_lora_only: true
   lora_config:
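
text2semantic_finetune_lora.yaml receives the identical model swap; the only difference from the file above is the LoRA tail, whose lora_config contents the hunk cuts off. A condensed sketch (the save_lora_only comment is my gloss):

```yaml
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
    config:
      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
      # ... identical hyperparameters to text2semantic_finetune.yaml above ...
      num_codebooks: 8
      codebook_size: 264

  save_lora_only: true  # presumably: checkpoint only the adapter weights
  lora_config:
    # (contents not shown in this hunk)
```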

+ 3 - 3
fish_speech/configs/text2semantic_pretrain_small.yaml

@@ -52,18 +52,18 @@ model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
 
   model:
-    # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
+    _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
     config:
       _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
       max_seq_len: ${max_length}
       vocab_size: 36408
       n_layer: 12
-      n_fast_layer: 4
+      # n_fast_layer: 4
       n_head: 12
       dim: 768
       rope_base: 10000
       norm_eps: 1e-5
+      num_in_codebooks: 4
       num_codebooks: 8  # input/output codebook size
       codebook_size: 264 # codebook size 256 + 2 special tokens
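
Note what this hunk leaves in place: text2semantic_pretrain_small.yaml now targets NaiveTransformer but still passes DualARModelArgs, with n_fast_layer commented out rather than removed and a new num_in_codebooks key. The resulting block should read as follows (the comments on n_fast_layer and num_in_codebooks are my inference from the names, not part of the file):

```yaml
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
    config:
      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
      max_seq_len: ${max_length}
      vocab_size: 36408
      n_layer: 12
      # n_fast_layer: 4    # presumably ignored by NaiveTransformer, hence commented out
      n_head: 12
      dim: 768
      rope_base: 10000
      norm_eps: 1e-5
      num_in_codebooks: 4   # likely: number of codebooks consumed on the input side
      num_codebooks: 8  # input/output codebook size
      codebook_size: 264 # codebook size 256 + 2 special tokens
```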
 

+ 3 - 4
fish_speech/configs/text2semantic_sft_medium.yaml

@@ -62,13 +62,12 @@ model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
 
   model:
-    # ~ 130M parameters, for debug purpose
-    _target_: fish_speech.models.text2semantic.llama.Transformer
+    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
     config:
-      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
       max_seq_len: ${max_length}
       vocab_size: 36408
-      n_slow_layer: 24
+      n_layer: 24
       n_fast_layer: 6
       n_head: 16
       dim: 1024
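
The medium SFT config gets the same DualAR swap; the removed n_slow_layer: 24 becomes n_layer: 24, matching the key the other configs in this commit use. The hunk breaks off after dim, so only the visible portion is reconstructed here:

```yaml
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
    config:
      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
      max_seq_len: ${max_length}
      vocab_size: 36408
      n_layer: 24       # renamed from n_slow_layer in this commit
      n_fast_layer: 6
      n_head: 16
      dim: 1024
      # ... rest of the hunk not shown ...
```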