
Optimize llama sft config

Lengyue 2 years ago
parent
commit
55ac942c83

+ 3 - 0
fish_speech/configs/text2semantic_pretrain_small.yaml

@@ -4,6 +4,7 @@ defaults:
 
 project: text2semantic_pretrain_small_4_in_8_codebooks
 max_length: 2048
+use_delay_pattern: true
 
 # Lightning Trainer
 trainer:
@@ -28,6 +29,7 @@ train_dataset:
   use_speaker: false
   phones_prob: 0.5
   interactive_prob: 0.5
+  use_delay_pattern: ${use_delay_pattern}
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
@@ -37,6 +39,7 @@ val_dataset:
   use_speaker: false
   phones_prob: 0.5
   interactive_prob: 0.5
+  use_delay_pattern: ${use_delay_pattern}
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
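Note on the new flag: a delay pattern staggers codebook k by k time steps, so within a frame each codebook only conditions on codes already emitted at earlier positions. A minimal illustrative sketch of that layout (apply_delay_pattern is a hypothetical name, not a helper from this repo):

import torch

def apply_delay_pattern(codes: torch.Tensor, pad: int = 0) -> torch.Tensor:
    # codes: (num_codebooks, T) -> (num_codebooks, T + num_codebooks - 1),
    # row k shifted right by k positions, gaps filled with pad.
    num_codebooks, length = codes.shape
    out = torch.full(
        (num_codebooks, length + num_codebooks - 1), pad, dtype=codes.dtype
    )
    for k in range(num_codebooks):
        out[k, k : k + length] = codes[k]
    return out

The slicing added to tools/llama/generate.py below is the inverse of this layout.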

+ 21 - 11
fish_speech/configs/text2semantic_sft.yaml → fish_speech/configs/text2semantic_sft_medium.yaml

@@ -2,17 +2,19 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_400m_sft_1.0
+project: text2semantic_sft_medium
 max_length: 4096
+use_delay_pattern: true
 
 # Lightning Trainer
 trainer:
   accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
-  max_steps: 20_000
+  max_steps: 10_000
   precision: bf16-true
   limit_val_batches: 10
+  val_check_interval: 1000
 
 # Dataset Configuration
 tokenizer:
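For context: with an integer value, Lightning's val_check_interval counts training batches, so this configuration validates every 1000 batches over the 10_000-step run, with each validation pass capped at limit_val_batches: 10. A sketch of the equivalent Trainer call in plain Python (illustrative, not the repo's actual launch code):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accumulate_grad_batches=1,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
    max_steps=10_000,
    precision="bf16-true",
    limit_val_batches=10,
    val_check_interval=1000,  # validate every 1000 training batches
)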
@@ -22,30 +24,37 @@ tokenizer:
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
+  use_data_server: false
+  proto_files:
+    - data/protos/sft/train_Genshin.protos
+    - data/protos/sft/sft.protos
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
-  use_negative_samples: true
+  use_delay_pattern: ${use_delay_pattern}
 
 val_dataset:
   _target_: fish_speech.datasets.text.AutoAugTextDataset
+  use_data_server: false
+  proto_files:
+    - data/protos/sft/val_Genshin.protos
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   num_codebooks: ${model.model.config.num_codebooks}
   use_speaker: true
   phones_prob: 0.5
   interactive_prob: 0.5
-  use_negative_samples: true
+  use_delay_pattern: ${use_delay_pattern}
 
 data:
   _target_: fish_speech.datasets.text.TextDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 4
+  batch_size: 32
   tokenizer: ${tokenizer}
   max_length: ${max_length}
 
@@ -65,16 +74,17 @@ model:
       dim: 1024
       rope_base: 10000
       norm_eps: 1e-5
-      num_codebooks: 4  # single codebook
+      num_in_codebooks: 4 # input codebook size
+      num_codebooks: 8  # output codebook size
       codebook_size: 264 # codebook size 256 + 2 special tokens
-      dropout: 0.1
-      neft_alpha: 10
+      dropout: 0
+      neft_alpha: 0
 
   optimizer:
-    _target_: torch.optim.AdamW
+    _target_: bitsandbytes.optim.AdamW8bit
     _partial_: true
     lr: 1e-4
-    weight_decay: 0.01
+    weight_decay: 0
     betas: [0.9, 0.95]
     eps: 1e-5
 
@@ -86,4 +96,4 @@ model:
       _partial_: true
       num_warmup_steps: 200
       num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
+      final_lr_ratio: 0
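Two things change in the model block: the config now separates input codebooks (num_in_codebooks: 4) from output codebooks (num_codebooks: 8), matching the 4_in_8 naming of the pretrain config above, and the optimizer switches to bitsandbytes' 8-bit AdamW, whose 8-bit optimizer states cut optimizer memory roughly 4x versus fp32 Adam states. With _partial_: true, Hydra instantiates the optimizer as a functools.partial so the training module can bind parameters later. A minimal sketch of that pattern (the toy model is illustrative; the actual wiring lives in fish_speech's Lightning module):

import hydra
import torch.nn as nn
from omegaconf import OmegaConf

opt_cfg = OmegaConf.create({
    "_target_": "bitsandbytes.optim.AdamW8bit",
    "_partial_": True,
    "lr": 1e-4,
    "weight_decay": 0,
    "betas": [0.9, 0.95],
    "eps": 1e-5,
})

model = nn.Linear(16, 16)  # stand-in for the real transformer
make_optimizer = hydra.utils.instantiate(opt_cfg)  # returns a functools.partial
optimizer = make_optimizer(model.parameters())     # parameters bound at setup time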

+ 1 - 1
fish_speech/datasets/text.py

@@ -194,7 +194,7 @@ class AutoAugTextDataset(IterableDataset):
         tokenizer: AutoTokenizer = None,
         use_speaker: bool = True,
         use_data_server: bool = True,
-        proto_files: str = "data",
+        proto_files: Optional[list[str]] = None,
         causual: bool = True,
         mix_text_phone_prob: float = 0.5,
         use_negative_samples: bool = False,
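The old default of proto_files: str = "data" was the wrong type; the new signature takes an explicit list of .protos paths, with None as the sentinel (presumably meaning "fall back to the data server", given the paired use_data_server flags in the configs). The annotation needs Optional from typing, plus Python 3.9+ for the bare list[str] form. An illustrative standalone version of the idiom (not the repo's code):

from typing import Optional

def resolve_proto_files(proto_files: Optional[list[str]] = None) -> list[str]:
    # None rather than a mutable [] default: list defaults are shared across calls.
    return list(proto_files) if proto_files is not None else []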

+ 15 - 11
tools/llama/generate.py

@@ -376,7 +376,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     model = model.to(device=device, dtype=precision)
     logger.info("Restored model from checkpoint")
 
-    return model.eval()
+    return model.eval(), cfg
 
 
 def split_text(text, min_length):
@@ -451,7 +451,7 @@ def main(
 
     logger.info("Loading model ...")
     t0 = time.time()
-    model = load_model(config_name, checkpoint_path, device, precision)
+    model, cfg = load_model(config_name, checkpoint_path, device, precision)
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
     torch.cuda.synchronize()
@@ -466,12 +466,13 @@ def main(
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
-    for text in split_text(text, 20):
+    texts = split_text(text, 20) if iterative_prompt else [text]
+    for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
                 tokenizer,
-                text,
-                bos=False,
+                string=text,
+                bos=idx == 0 and not use_prompt,
                 device=device,
                 use_g2p=use_g2p,
                 speaker=None,
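The rewritten loop splits the text only when iterative_prompt is set, and emits BOS exactly once: on the first segment, and only if a reference prompt has not already opened the sequence. A self-contained illustration of the flag logic (function name is mine, not the repo's):

def bos_flags(num_segments: int, use_prompt: bool) -> list[bool]:
    # BOS appears at most once per generated sequence
    return [idx == 0 and not use_prompt for idx in range(num_segments)]

print(bos_flags(3, use_prompt=False))  # [True, False, False]
print(bos_flags(3, use_prompt=True))   # [False, False, False]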
@@ -553,13 +554,16 @@ def main(
 
             # Put the generated tokens
             codes = y[1:, prompt_length:-1].clone()
-            new_codes = []
-            for j, code in enumerate(codes):
-                new_codes.append(
-                    code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
-                )
 
-            codes = torch.stack(new_codes, dim=0)
+            if getattr(cfg, "use_delay_pattern", True):
+                new_codes = []
+                for j, code in enumerate(codes):
+                    new_codes.append(
+                        code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+                    )
+
+                codes = torch.stack(new_codes, dim=0)
+
             codes = codes - 2
             if not (codes >= 0).all():
                 global_encoded.pop()
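The realignment is now guarded by the config flag returned from load_model, with getattr defaulting to True so older checkpoints without the field keep the previous behavior. The guarded branch is the inverse of the delay pattern: row j is valid from step j through T - (num_codebooks - 1 - j), so the slicing restores a common time axis across codebooks. As a standalone sketch (hypothetical helper name):

import torch

def undo_delay_pattern(codes: torch.Tensor) -> torch.Tensor:
    # codes: (num_codebooks, T) in staggered layout
    # -> aligned (num_codebooks, T - num_codebooks + 1)
    num_codebooks, length = codes.shape
    rows = [
        codes[j, j : length - (num_codebooks - j - 1)]
        for j in range(num_codebooks)
    ]
    return torch.stack(rows, dim=0)

The subsequent codes = codes - 2 removes the offset reserved for the two special tokens, and any resulting negative value signals an invalid generation, which is rolled back via global_encoded.pop().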