Explorar o código

Optimize config system & enable faster loading

Lengyue hai 2 anos
pai
achega
ea101fcd0d

+ 9 - 0
fish_speech/configs/model/dual_ar_8_codebook_medium.yaml

@@ -0,0 +1,9 @@
+defaults:
+  - dual_ar_8_codebook_small
+  - _self_
+
+config:
+  n_layer: 24
+  n_fast_layer: 6
+  n_head: 16
+  dim: 1024

+ 13 - 0
fish_speech/configs/model/dual_ar_8_codebook_small.yaml

@@ -0,0 +1,13 @@
+_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
+config:
+  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+  max_seq_len: ${max_length}
+  vocab_size: 36408
+  n_layer: 12
+  n_fast_layer: 4
+  n_head: 12
+  dim: 768
+  rope_base: 10000
+  norm_eps: 1e-5
+  num_codebooks: 8  # input/output codebook size
+  codebook_size: 264 # codebook size 256 + 2 special tokens

+ 12 - 0
fish_speech/configs/model/naive_8_codebook_small.yaml

@@ -0,0 +1,12 @@
+_target_: fish_speech.models.text2semantic.llama.NaiveTransformer
+config:
+  _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
+  max_seq_len: ${max_length}
+  vocab_size: 36408
+  n_layer: 12
+  n_head: 12
+  dim: 768
+  rope_base: 10000
+  norm_eps: 1e-5
+  num_codebooks: 8  # input/output codebook size
+  codebook_size: 264 # codebook size 256 + 2 special tokens

+ 2 - 15
fish_speech/configs/text2semantic_finetune.yaml

@@ -1,5 +1,6 @@
 defaults:
   - base
+  - model@model.model: dual_ar_8_codebook_small
   - _self_
 
 project: text2semantic_400m_finetune_spk
@@ -46,21 +47,7 @@ data:
 # Model Configuration
 model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
-
-  model:
-    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
-    config:
-      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-      max_seq_len: ${max_length}
-      vocab_size: 36408
-      n_layer: 24
-      n_fast_layer: 6
-      n_head: 16
-      dim: 1024
-      rope_base: 10000
-      norm_eps: 1e-5
-      num_codebooks: 8  # input/output codebook size
-      codebook_size: 264 # codebook size 256 + 2 special tokens
+  model: {}
 
   optimizer:
     _target_: torch.optim.AdamW

+ 1 - 82
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -1,93 +1,12 @@
 defaults:
-  - base
+  - text2semantic_finetune
   - _self_
 
 project: text2semantic_400m_finetune_lora
-max_length: 4096
-ckpt_path: checkpoints/text2semantic-400m-v0.3-4k.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accumulate_grad_batches: 2
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  max_steps: 1000
-  precision: bf16-true
-  limit_val_batches: 10
-  log_every_n_steps: 10
-
-# Dataset Configuration
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-v1
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-
-val_dataset:
-  _target_: fish_speech.datasets.text.AutoAugTextDataset
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-  num_codebooks: ${model.model.config.num_codebooks}
-
-data:
-  _target_: fish_speech.datasets.text.TextDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 8
-  tokenizer: ${tokenizer}
-  max_length: ${max_length}
-
 # Model Configuration
 model:
-  _target_: fish_speech.models.text2semantic.TextToSemantic
-
-  model:
-    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
-    config:
-      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-      max_seq_len: ${max_length}
-      vocab_size: 36408
-      n_layer: 24
-      n_fast_layer: 6
-      n_head: 16
-      dim: 1024
-      rope_base: 10000
-      norm_eps: 1e-5
-      num_codebooks: 8  # input/output codebook size
-      codebook_size: 264 # codebook size 256 + 2 special tokens
-
   save_lora_only: true
   lora_config:
     _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
     r: 8
     lora_alpha: 16
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 3e-4
-    weight_decay: 0.1
-    betas: [0.9, 0.95]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 100
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.1
-
-# Callbacks
-callbacks:
-  model_checkpoint:
-    every_n_train_steps: 200

+ 3 - 17
fish_speech/configs/text2semantic_pretrain_small.yaml → fish_speech/configs/text2semantic_pretrain.yaml

@@ -1,8 +1,9 @@
 defaults:
   - base
+  - model@model.model: dual_ar_8_codebook_small
   - _self_
 
-project: text2semantic_pretrain_small_dual_ar
+project: text2semantic_pretrain_dual_ar_debug
 max_length: 2048
 
 # Lightning Trainer
@@ -50,22 +51,7 @@ data:
 # Model Configuration
 model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
-
-  model:
-    _target_: fish_speech.models.text2semantic.llama.NaiveTransformer
-    config:
-      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-      max_seq_len: ${max_length}
-      vocab_size: 36408
-      n_layer: 12
-      # n_fast_layer: 4
-      n_head: 12
-      dim: 768
-      rope_base: 10000
-      norm_eps: 1e-5
-      num_in_codebooks: 4
-      num_codebooks: 8  # input/output codebook size
-      codebook_size: 264 # codebook size 256 + 2 special tokens
+  model: {}
 
   optimizer:
     _target_: torch.optim.AdamW

+ 0 - 14
fish_speech/configs/text2semantic_pretrain_medium.yaml

@@ -1,14 +0,0 @@
-defaults:
-  - text2semantic_pretrain_small
-  - _self_
-
-project: text2semantic_pretrain_medium_dual_ar
-
-# Model Configuration
-model:
-  model:
-    config:
-      n_layer: 24
-      n_fast_layer: 6
-      n_head: 16
-      dim: 1024

+ 2 - 15
fish_speech/configs/text2semantic_sft_medium.yaml → fish_speech/configs/text2semantic_sft.yaml

@@ -1,5 +1,6 @@
 defaults:
   - base
+  - model@model.model: dual_ar_8_codebook_small
   - _self_
 
 project: text2semantic_sft_medium_dual_ar
@@ -60,21 +61,7 @@ data:
 # Model Configuration
 model:
   _target_: fish_speech.models.text2semantic.TextToSemantic
-
-  model:
-    _target_: fish_speech.models.text2semantic.llama.DualARTransformer
-    config:
-      _target_: fish_speech.models.text2semantic.llama.DualARModelArgs
-      max_seq_len: ${max_length}
-      vocab_size: 36408
-      n_layer: 24
-      n_fast_layer: 6
-      n_head: 16
-      dim: 1024
-      rope_base: 10000
-      norm_eps: 1e-5
-      num_codebooks: 8  # input/output codebook size
-      codebook_size: 264 # codebook size 256 + 2 special tokens
+  model: {}
 
   optimizer:
     _target_: torch.optim.AdamW

+ 30 - 23
tools/llama/generate.py

@@ -90,22 +90,21 @@ def sample(
 
 
 def decode_one_token_ar(
-    model: NaiveTransformer,
+    model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
-    assert input_pos.shape[-1] == 1
-
-    x, logits = model.forward_generate_slow(x, input_pos)
+    x = model.forward_generate(x, input_pos)
     codebooks = [
         sample(
-            logits,
+            x.logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
     ]
+    x = x.hidden_states
 
     # Cleanup the cache
     for layer in model.fast_layers:
@@ -137,12 +136,11 @@ def decode_one_token_naive(
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
-    assert input_pos.shape[-1] == 1
+    x = model.forward_generate(x, input_pos)
 
-    x, logits = model.forward_generate_slow(x, input_pos)
     codebooks = [
         sample(
-            logits,
+            x.token_logits,
             previous_tokens=None,  # Disable repetition penalty for the token codebook
             **sampling_kwargs,
         )[0]
@@ -151,7 +149,7 @@ def decode_one_token_naive(
     for i in range(model.config.num_codebooks):
         codebooks.append(
             sample(
-                logits.codebook_logits[:, :, i],
+                x.codebook_logits[:, :, i],
                 previous_tokens=previous_tokens[i + 1]
                 if previous_tokens is not None
                 else None,
@@ -343,11 +341,13 @@ def encode_tokens(
     return prompt
 
 
-def load_model(config_name, checkpoint_path, device, precision):
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
+def load_model(config_name, checkpoint_path, device, precision, max_length):
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
+        cfg = compose(
+            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
+        )
 
-    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg.model).model
+    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
 
     if "int8" in str(checkpoint_path):
         logger.info("Using int8 weight-only quantization!")
@@ -421,7 +421,7 @@ def split_text(text, min_length):
     type=click.Path(path_type=Path, exists=True),
     default="results/text2semantic_400m_finetune/step_000002000.pth",
 )
-@click.option("--config-name", type=str, default="text2semantic_finetune")
+@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
 @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--use-g2p/--no-g2p", default=True)
@@ -430,6 +430,8 @@ def split_text(text, min_length):
 @click.option("--order", type=str, default="zh,jp,en")
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=False)
+@click.option("--max-length", type=int, default=2048)
+@click.option("--chunk-length", type=int, default=30)
 def main(
     text: str,
     prompt_text: Optional[str],
@@ -450,6 +452,8 @@ def main(
     order: str,
     half: bool,
     iterative_prompt: bool,
+    max_length: int,
+    chunk_length: int,
 ) -> None:
     device = "cuda"
 
@@ -457,7 +461,7 @@ def main(
 
     logger.info("Loading model ...")
     t0 = time.time()
-    model, cfg = load_model(config_name, checkpoint_path, device, precision)
+    model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
     torch.cuda.synchronize()
@@ -472,7 +476,7 @@ def main(
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
-    texts = split_text(text, 30) if iterative_prompt else [text]
+    texts = split_text(text, chunk_length) if iterative_prompt else [text]
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -506,13 +510,15 @@ def main(
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    decode_one_token = (
-        decode_one_token_ar
-        if isinstance(model, DualARTransformer)
-        else decode_one_token_naive
-    )
+    if isinstance(model, DualARTransformer):
+        decode_one_token = decode_one_token_ar
+        logger.info("Using DualARTransformer")
+    else:
+        decode_one_token = decode_one_token_naive
+        logger.info("Using NaiveTransformer")
 
     if compile:
+        logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token, mode="reduce-overhead", fullgraph=True
         )
@@ -528,11 +534,12 @@ def main(
             global_encoded.append(seg)
 
             lengths = reversed([seg.size(1) for seg in global_encoded])
+
             # Pick last 2000 tokens
             count = 0
             for i, length in enumerate(lengths):
                 count += length
-                if count >= 2000:
+                if count + length > max_length - 1024:
                     break
 
             if i != 0 and i % 2 == 0:
@@ -561,7 +568,7 @@ def main(
                 repetition_penalty=repetition_penalty,
             )
 
-            if idx == 0 and compile:
+            if idx == 0 and seg_idx == 0 and compile:
                 logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
             torch.cuda.synchronize()