فهرست منبع

Add sft & interactive config

Lengyue 2 سال پیش
والد
کامیت
6a85ec10c7
3 فایل‌های تغییر یافته به همراه 13 افزوده شده و 17 حذف شده
  1. 1 1
      fish_speech/configs/text2semantic_pretrain_small.yaml
  2. 10 15
      fish_speech/configs/text2semantic_sft_medium.yaml
  3. 2 1
      tools/llama/generate.py

+ 1 - 1
fish_speech/configs/text2semantic_pretrain_small.yaml

@@ -43,7 +43,7 @@ data:
   train_dataset: ${train_dataset}
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
   num_workers: 4
-  batch_size: 32
+  batch_size: 8
   tokenizer: ${tokenizer}
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   max_length: ${max_length}
 
 

+ 10 - 15
fish_speech/configs/text2semantic_sft_medium.yaml

@@ -2,10 +2,9 @@ defaults:
   - base
   - base
   - _self_
   - _self_
 
 
-project: text2semantic_sft_medium_delay
+project: text2semantic_sft_medium_dual_ar
 max_length: 4096
 max_length: 4096
-use_delay_pattern: false
-ckpt_path: results/text2semantic_pretrain_medium_4_in_8_codebooks/checkpoints/step_000100000.ckpt
+ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
 resume_weights_only: true
 resume_weights_only: true
 
 
 # Lightning Trainer
 # Lightning Trainer
@@ -33,10 +32,9 @@ train_dataset:
   tokenizer: ${tokenizer}
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
   num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: true
+  use_speaker: false
   phones_prob: 0.5
   phones_prob: 0.5
   interactive_prob: 0.5
   interactive_prob: 0.5
-  use_delay_pattern: ${use_delay_pattern}
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -46,17 +44,16 @@ val_dataset:
   tokenizer: ${tokenizer}
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
   num_codebooks: ${model.model.config.num_codebooks}
-  use_speaker: true
+  use_speaker: false
   phones_prob: 0.5
   phones_prob: 0.5
   interactive_prob: 0.5
   interactive_prob: 0.5
-  use_delay_pattern: ${use_delay_pattern}
 
 
 data:
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   _target_: fish_speech.datasets.text.TextDataModule
   train_dataset: ${train_dataset}
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
   num_workers: 4
-  batch_size: 16
+  batch_size: 8
   tokenizer: ${tokenizer}
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   max_length: ${max_length}
 
 
@@ -69,21 +66,19 @@ model:
     _target_: fish_speech.models.text2semantic.llama.Transformer
     _target_: fish_speech.models.text2semantic.llama.Transformer
     config:
     config:
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
       _target_: fish_speech.models.text2semantic.llama.ModelArgs
-      max_seq_len: 4096
+      max_seq_len: ${max_length}
       vocab_size: 36408
       vocab_size: 36408
-      n_layer: 24
+      n_slow_layer: 24
+      n_fast_layer: 6
       n_head: 16
       n_head: 16
       dim: 1024
       dim: 1024
       rope_base: 10000
       rope_base: 10000
       norm_eps: 1e-5
       norm_eps: 1e-5
-      num_in_codebooks: 4 # input codebook size
-      num_codebooks: 8  # output codebook size
+      num_codebooks: 8  # input/output codebook size
       codebook_size: 264 # codebook size 256 + 2 special tokens
       codebook_size: 264 # codebook size 256 + 2 special tokens
-      dropout: 0
-      neft_alpha: 0
 
 
   optimizer:
   optimizer:
-    _target_: bitsandbytes.optim.AdamW8bit
+    _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
     lr: 4e-5
     lr: 4e-5
     weight_decay: 0
     weight_decay: 0

+ 2 - 1
tools/llama/generate.py

@@ -366,9 +366,10 @@ def encode_tokens(
         data = data[:num_codebooks]
         data = data[:num_codebooks]
 
 
     # Since 1.0, we use <s:xxx> to replace <semantic>
     # Since 1.0, we use <s:xxx> to replace <semantic>
+    s0_token_id = tokenizer.convert_tokens_to_ids("<s:0>")
     main_token_ids = torch.tensor(
     main_token_ids = torch.tensor(
         # TODO: replace this
         # TODO: replace this
-        [[tokenizer.pad_token_id] * data.size(1)],
+        [[s0_token_id] * data.size(1)],
         dtype=torch.int,
         dtype=torch.int,
         device=device,
         device=device,
     )
     )