From 01166841cfe118b8792161a43af4db416835073e Mon Sep 17 00:00:00 2001
From: Kdump
Date: Wed, 26 Mar 2025 15:25:02 +0800
Subject: [PATCH] [trainer] fix wsd scheduler (#7304)

* [trainer] warmup_stable_decay supports setting the number of stable and decay steps according to the warmup_ratio

* Update trainer_utils.py

---------

Co-authored-by: hoshi-hiyouga
---
 src/llamafactory/train/trainer_utils.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 198936eb..2380540e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -515,6 +515,22 @@ def create_custom_scheduler(
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
+    if training_args.lr_scheduler_type == "warmup_stable_decay":
+        num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
+        remaining_steps = num_training_steps - num_warmup_steps
+        num_stable_steps = remaining_steps // 3  # use 1/3 for stable by default
+        num_decay_steps = remaining_steps - num_stable_steps
+        scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
+        default_kwargs = {
+            "num_stable_steps": num_stable_steps,
+            "num_decay_steps": num_decay_steps,
+        }
+        for key, value in default_kwargs.items():
+            if key not in scheduler_kwargs:
+                scheduler_kwargs[key] = value
+
+        training_args.lr_scheduler_kwargs = scheduler_kwargs
+
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
         scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
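
Editor's note: the hunk above derives default num_stable_steps/num_decay_steps from the configured warmup (warmup_steps or warmup_ratio) and stores them in training_args.lr_scheduler_kwargs, which the Hugging Face Trainer later forwards to the scheduler factory (as scheduler_specific_kwargs in recent transformers versions), where the warmup_stable_decay schedule needs both values. Below is a minimal standalone sketch of the same arithmetic and default-merging; the concrete numbers (1000 training steps, warmup_ratio=0.1) and the user override are hypothetical examples, not LLaMA-Factory internals.

# Standalone sketch of the step split and default-merge performed in the patch above.
# Assumes 1000 training steps and warmup_ratio=0.1 purely for illustration.
num_training_steps = 1000
num_warmup_steps = int(num_training_steps * 0.1)         # 100 warmup steps
remaining_steps = num_training_steps - num_warmup_steps  # 900
num_stable_steps = remaining_steps // 3                  # 300 (1/3 stable by default)
num_decay_steps = remaining_steps - num_stable_steps     # 600

# User-supplied lr_scheduler_kwargs take precedence; defaults only fill missing keys.
scheduler_kwargs = {"num_decay_steps": 200}  # hypothetical user override
for key, value in {"num_stable_steps": num_stable_steps, "num_decay_steps": num_decay_steps}.items():
    scheduler_kwargs.setdefault(key, value)
print(scheduler_kwargs)  # {'num_decay_steps': 200, 'num_stable_steps': 300}

With no user override, the defaults give a warmup/stable/decay split of 100/300/600 steps; with the override above, only num_stable_steps is filled in, mirroring the `if key not in scheduler_kwargs` guard in the patch.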