Преглед на файлове

support smart warmup (float)

Lengyue преди 1 година
родител
ревизия
fadbdae513
променени са 2 файла, в които са добавени 5 реда и са изтрити 2 реда
  1. 1 1
      fish_speech/configs/text2semantic_finetune.yaml
  2. 4 1
      fish_speech/scheduler.py

+ 1 - 1
fish_speech/configs/text2semantic_finetune.yaml

@@ -72,7 +72,7 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 100
+      num_warmup_steps: 0.1
       num_training_steps: ${trainer.max_steps}
 
 # Callbacks

+ 4 - 1
fish_speech/scheduler.py

@@ -4,11 +4,14 @@ import math
 def get_cosine_schedule_with_warmup_lr_lambda(
     current_step: int,
     *,
-    num_warmup_steps: int,
+    num_warmup_steps: int | float,
     num_training_steps: int,
     num_cycles: float = 0.5,
     final_lr_ratio: float = 0.0,
 ):
+    if 0 < num_warmup_steps < 1:  # float mode
+        num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))