diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 198936eb..2380540e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -515,6 +515,22 @@ def create_custom_scheduler(
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
+    if training_args.lr_scheduler_type == "warmup_stable_decay":
+        num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
+        remaining_steps = num_training_steps - num_warmup_steps
+        num_stable_steps = remaining_steps // 3  # use 1/3 for stable by default
+        num_decay_steps = remaining_steps - num_stable_steps
+        scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
+        default_kwargs = {
+            "num_stable_steps": num_stable_steps,
+            "num_decay_steps": num_decay_steps,
+        }
+        for key, value in default_kwargs.items():
+            if key not in scheduler_kwargs:
+                scheduler_kwargs[key] = value
+
+        training_args.lr_scheduler_kwargs = scheduler_kwargs
+
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
         scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
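
For reference, a minimal sketch of the step arithmetic this hunk introduces, using made-up numbers (3,000 training steps, 300 warmup steps): the steps remaining after warmup are split 1/3 stable and 2/3 decay, and any key the user already set in `lr_scheduler_kwargs` is left untouched.

```python
# Illustrative sketch only (not part of the patch); the concrete numbers are hypothetical.
num_training_steps = 3000
num_warmup_steps = 300  # stands in for training_args.get_warmup_steps(num_training_steps)
remaining_steps = num_training_steps - num_warmup_steps  # 2700
num_stable_steps = remaining_steps // 3                  # 900: 1/3 of the remainder stays flat
num_decay_steps = remaining_steps - num_stable_steps     # 1800: the rest decays

# A user who passes only num_decay_steps keeps that value; the missing
# num_stable_steps falls back to the computed default.
scheduler_kwargs = {"num_decay_steps": 500}
default_kwargs = {"num_stable_steps": num_stable_steps, "num_decay_steps": num_decay_steps}
for key, value in default_kwargs.items():
    if key not in scheduler_kwargs:
        scheduler_kwargs[key] = value

assert scheduler_kwargs == {"num_decay_steps": 500, "num_stable_steps": 900}
```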