[trainer] fix wsd scheduler (#7304)

* [trainer] warmup_stable_decay: derive the number of stable and decay steps from the warmup steps computed via warmup_ratio

* Update trainer_utils.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Kdump 2025-03-26 15:25:02 +08:00 committed by GitHub
parent 59e12bffe8
commit 01166841cf

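For context on what the fix computes: with `num_training_steps = 1000` and `warmup_ratio = 0.1`, warmup takes 100 steps, one third of the remaining 900 steps (300) stay at the stable learning rate, and the last 600 decay. A minimal standalone sketch of the same arithmetic (the helper name `split_wsd_steps` is illustrative; the patch itself calls `training_args.get_warmup_steps`):

```python
def split_wsd_steps(num_training_steps: int, warmup_ratio: float) -> tuple[int, int, int]:
    """Mirror the fix: warmup from warmup_ratio, 1/3 of the rest stable, remainder decay."""
    num_warmup_steps = int(num_training_steps * warmup_ratio)
    remaining_steps = num_training_steps - num_warmup_steps
    num_stable_steps = remaining_steps // 3  # same 1/3 default as the patch
    num_decay_steps = remaining_steps - num_stable_steps
    return num_warmup_steps, num_stable_steps, num_decay_steps

print(split_wsd_steps(1000, 0.1))  # -> (100, 300, 600)
```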

@@ -515,6 +515,22 @@ def create_custom_scheduler(
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
+    if training_args.lr_scheduler_type == "warmup_stable_decay":
+        num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
+        remaining_steps = num_training_steps - num_warmup_steps
+        num_stable_steps = remaining_steps // 3  # use 1/3 for stable by default
+        num_decay_steps = remaining_steps - num_stable_steps
+        scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
+        default_kwargs = {
+            "num_stable_steps": num_stable_steps,
+            "num_decay_steps": num_decay_steps,
+        }
+        for key, value in default_kwargs.items():
+            if key not in scheduler_kwargs:
+                scheduler_kwargs[key] = value
+
+        training_args.lr_scheduler_kwargs = scheduler_kwargs
+
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
         scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
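Note that the `if key not in scheduler_kwargs` loop only fills in missing keys, so values a user passes via `lr_scheduler_kwargs` take precedence over the computed defaults. The resulting kwargs are then consumed by the underlying warmup-stable-decay scheduler; as a sanity check on the shape they describe, here is a self-contained sketch that builds an analogous schedule with a plain `LambdaLR` (linear warmup and linear decay are assumptions for illustration, not necessarily the exact curve the real scheduler uses):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def wsd_lambda(num_warmup_steps: int, num_stable_steps: int, num_decay_steps: int):
    """LR multiplier: linear warmup -> constant plateau -> linear decay to zero."""
    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        if step < num_warmup_steps + num_stable_steps:
            return 1.0
        decay_step = step - num_warmup_steps - num_stable_steps
        return max(0.0, 1.0 - decay_step / max(1, num_decay_steps))
    return lr_lambda

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = LambdaLR(optimizer, wsd_lambda(100, 300, 600))  # the split from the example above
for step in (0, 100, 399, 400, 999):
    print(step, wsd_lambda(100, 300, 600)(step))  # multiplier at selected steps
```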