
Update new config and generation scripts

Lengyue 2 years ago
Parent
Commit 5f9e2f2380

+ 1 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -65,7 +65,7 @@ model:
     _target_: torch.optim.AdamW
     _partial_: true
     lr: 1e-5
-    weight_decay: 0.1
+    weight_decay: 0
     betas: [0.9, 0.95]
     eps: 1e-5
 

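A note on the optimizer block this hunk modifies: `_partial_: true` tells Hydra to return a `functools.partial` of `torch.optim.AdamW` rather than a constructed optimizer, so the training loop can bind `model.parameters()` later. A minimal sketch of how the block resolves (the final, commented-out call is illustrative):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "_target_": "torch.optim.AdamW",
        "_partial_": True,
        "lr": 1e-5,
        "weight_decay": 0,  # value after this commit
        "betas": [0.9, 0.95],
        "eps": 1e-5,
    })

    make_optimizer = instantiate(cfg)  # functools.partial(AdamW, lr=1e-5, ...)
    # optimizer = make_optimizer(model.parameters())  # bound at trainer setup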
+ 3 - 3
fish_speech/configs/text2semantic_pretrain.yaml

@@ -10,7 +10,7 @@ trainer:
   accumulate_grad_batches: 1
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
-  max_steps: 1_000_000
+  max_steps: 100_000
   precision: bf16-true
   limit_val_batches: 10
 
@@ -25,7 +25,7 @@ train_dataset:
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   use_speaker: false
-  phones_prob: 1.0
+  phones_prob: 0.5
   interactive_prob: 0.5
 
 val_dataset:
@@ -33,7 +33,7 @@ val_dataset:
   tokenizer: ${tokenizer}
   max_length: ${max_length}
   use_speaker: false
-  phones_prob: 1.0
+  phones_prob: 0.5
   interactive_prob: 0.5
 
 data:

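Lowering `phones_prob` from 1.0 to 0.5 presumably makes the dataset emit the phoneme (g2p) view of a sample only half the time instead of always, mixing in raw text. A hedged sketch of such probability-gated augmentation (helper names are assumptions, not the repo's actual API):

    import random

    def g2p_encode(text: str) -> str:
        # Stand-in for a real grapheme-to-phoneme pass (hypothetical helper).
        return f"<phones:{text}>"

    def render_sample(text: str, phones_prob: float = 0.5) -> str:
        # With probability phones_prob, train on phonemes; otherwise raw text.
        if random.random() < phones_prob:
            return g2p_encode(text)
        return text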
+ 85 - 0
fish_speech/configs/text2semantic_sft.yaml

@@ -0,0 +1,85 @@
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_400m_sft_1.0
+max_length: 4096
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 20_000
+  precision: bf16-true
+  limit_val_batches: 10
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-v1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+  use_speaker: true
+  phones_prob: 0.5
+  interactive_prob: 0.5
+
+val_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+  use_speaker: true
+  phones_prob: 0.5
+  interactive_prob: 0.5
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    # ~ 400M parameters
+    _target_: fish_speech.models.text2semantic.llama.Transformer
+    config:
+      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      max_seq_len: 4096
+      vocab_size: 36408
+      n_layer: 24
+      n_head: 16
+      dim: 1024
+      rope_base: 10000
+      norm_eps: 1e-5
+      num_codebooks: 4  # four codebooks
+      codebook_size: 264 # 256 codes + 2 special tokens, padded up to 264
+      dropout: 0.1
+      neft_alpha: 10
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    weight_decay: 0.01
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1

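Two details of the new SFT config are worth unpacking. First, the scheduler: the config binds `num_warmup_steps`, `num_training_steps`, and `final_lr_ratio` into a lambda for `LambdaLR`. A plausible shape for `get_cosine_schedule_with_warmup_lr_lambda`, matching those parameters (an assumption, not necessarily the repo's exact code), is linear warmup followed by cosine decay floored at `final_lr_ratio` of the peak LR:

    import math

    def get_cosine_schedule_with_warmup_lr_lambda(
        step: int,
        *,
        num_warmup_steps: int,
        num_training_steps: int,
        final_lr_ratio: float,
    ) -> float:
        # Linear warmup: 0 -> 1 over the first num_warmup_steps.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        # Cosine decay: 1 -> final_lr_ratio over the remaining steps.
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
        return final_lr_ratio + (1.0 - final_lr_ratio) * cosine

With `lr: 1e-4`, this warms up over 2,000 steps and decays to 1e-5 by step 20,000.

Second, `neft_alpha: 10` points at NEFTune-style embedding noise; the published recipe adds uniform noise scaled by alpha / sqrt(seq_len * dim) to input embeddings during training. A sketch under that assumption:

    import torch

    def neftune_noise(embeds: torch.Tensor, alpha: float = 10.0) -> torch.Tensor:
        # NEFTune: uniform noise in [-1, 1], scaled by alpha / sqrt(L * d),
        # applied to the input embeddings (training only).
        seq_len, dim = embeds.shape[-2], embeds.shape[-1]
        scale = alpha / (seq_len * dim) ** 0.5
        return embeds + torch.empty_like(embeds).uniform_(-scale, scale)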
+ 26 - 8
tools/llama/generate.py

@@ -264,16 +264,12 @@ def encode_tokens(
     string,
     bos=True,
     device="cuda",
-    prompt_text=None,
     prompt_tokens=None,
     use_g2p=False,
     speaker=None,
     order="zh,jp,en",
     num_codebooks=4,
 ):
-    if prompt_text is not None:
-        string = prompt_text + " " + string
-
     if use_g2p:
         order = order.split(",")
         prompt = g2p(string, order=order)
@@ -306,6 +302,12 @@ def encode_tokens(
         return prompt
 
     # Get prompt tokens
+    if prompt_tokens.ndim == 3:
+        assert (
+            prompt_tokens.shape[0] == 1
+        ), "3-dim prompt tokens should have shape (1, num_codebooks, seq_len)"
+        prompt_tokens = prompt_tokens[0]
+
     assert prompt_tokens.ndim == 2
     data = prompt_tokens + 2
 
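The new guard lets callers pass codec output that still carries a batch dimension; a standalone illustration of the shape handling around it (tensor contents and sizes are made up for illustration):

    import torch

    # A VQ codec commonly emits (batch, num_codebooks, seq_len); with batch == 1
    # the guard squeezes it to the 2-dim layout the rest of encode_tokens expects.
    prompt_tokens = torch.randint(0, 256, (1, 4, 50))
    if prompt_tokens.ndim == 3:
        assert prompt_tokens.shape[0] == 1
        prompt_tokens = prompt_tokens[0]  # -> (num_codebooks, seq_len)

    data = prompt_tokens + 2  # shift past the 2 special token ids, as in the hunk
    print(data.shape)         # torch.Size([4, 50])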
@@ -432,18 +434,34 @@ def main(
         if prompt_tokens is not None
         else None
     )
+
+    use_prompt = prompt_text is not None and prompt_tokens is not None
+
     encoded = encode_tokens(
         tokenizer,
         text,
-        prompt_text=prompt_text,
-        prompt_tokens=prompt_tokens,
-        bos=True,
+        bos=not use_prompt,
         device=device,
         use_g2p=use_g2p,
-        speaker=speaker,
+        speaker=None if use_prompt else speaker,
         order=order,
         num_codebooks=model.config.num_codebooks,
     )
+
+    if use_prompt:
+        encoded_prompt = encode_tokens(
+            tokenizer,
+            prompt_text,
+            prompt_tokens=prompt_tokens,
+            bos=True,
+            device=device,
+            use_g2p=use_g2p,
+            speaker=speaker,
+            order=order,
+            num_codebooks=model.config.num_codebooks,
+        )
+        encoded = torch.cat((encoded_prompt, encoded), dim=1)
+
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")
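
Net effect of the refactor: the reference prompt is encoded once with BOS and the speaker tag, the target text is encoded with neither, and the two are concatenated along the sequence axis before generation. A toy demonstration of that concatenation (the (num_codebooks + 1, seq_len) layout is an assumption about encode_tokens' output):

    import torch

    encoded_prompt = torch.zeros(5, 120, dtype=torch.long)  # BOS + speaker + reference
    encoded_target = torch.ones(5, 40, dtype=torch.long)    # target text, no BOS
    encoded = torch.cat((encoded_prompt, encoded_target), dim=1)
    print(encoded.size(1))  # 160 -> the prompt_length logged above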