|
@@ -4,11 +4,14 @@ import math
|
|
|
def get_cosine_schedule_with_warmup_lr_lambda(
|
|
def get_cosine_schedule_with_warmup_lr_lambda(
|
|
|
current_step: int,
|
|
current_step: int,
|
|
|
*,
|
|
*,
|
|
|
- num_warmup_steps: int,
|
|
|
|
|
|
|
+ num_warmup_steps: int | float,
|
|
|
num_training_steps: int,
|
|
num_training_steps: int,
|
|
|
num_cycles: float = 0.5,
|
|
num_cycles: float = 0.5,
|
|
|
final_lr_ratio: float = 0.0,
|
|
final_lr_ratio: float = 0.0,
|
|
|
):
|
|
):
|
|
|
|
|
+ if 0 < num_warmup_steps < 1: # float mode
|
|
|
|
|
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
|
|
|
|
|
+
|
|
|
if current_step < num_warmup_steps:
|
|
if current_step < num_warmup_steps:
|
|
|
return float(current_step) / float(max(1, num_warmup_steps))
|
|
return float(current_step) / float(max(1, num_warmup_steps))
|
|
|
|
|
|