mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[trainer] fix wsd scheduler (#7304)
* [trainer] warmup_stable_decay supports setting the number of stable and decay steps according to warmup_ratio

* Update trainer_utils.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
59e12bffe8
commit
01166841cf
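As the commit message describes, the stable and decay step counts are now derived from the warmup step count implied by warmup_ratio. Below is a minimal sketch of that arithmetic with hypothetical numbers; the concrete values and the math.ceil approximation of get_warmup_steps are assumptions for illustration, not taken from the commit.

import math

# Hypothetical values, chosen only to illustrate the split.
num_training_steps = 1000
warmup_ratio = 0.1

# Approximation of training_args.get_warmup_steps() when warmup_ratio is set (assumption).
num_warmup_steps = math.ceil(num_training_steps * warmup_ratio)  # 100
remaining_steps = num_training_steps - num_warmup_steps          # 900
num_stable_steps = remaining_steps // 3                          # 300, 1/3 stable by default
num_decay_steps = remaining_steps - num_stable_steps             # 600

print(num_warmup_steps, num_stable_steps, num_decay_steps)       # 100 300 600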
@@ -515,6 +515,22 @@ def create_custom_scheduler(
     num_training_steps: int,
     optimizer: Optional["torch.optim.Optimizer"] = None,
 ) -> None:
+    if training_args.lr_scheduler_type == "warmup_stable_decay":
+        num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
+        remaining_steps = num_training_steps - num_warmup_steps
+        num_stable_steps = remaining_steps // 3  # use 1/3 for stable by default
+        num_decay_steps = remaining_steps - num_stable_steps
+        scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
+        default_kwargs = {
+            "num_stable_steps": num_stable_steps,
+            "num_decay_steps": num_decay_steps,
+        }
+        for key, value in default_kwargs.items():
+            if key not in scheduler_kwargs:
+                scheduler_kwargs[key] = value
+
+        training_args.lr_scheduler_kwargs = scheduler_kwargs
+
     if optimizer is not None and isinstance(optimizer, DummyOptimizer):
         optimizer_dict = optimizer.optimizer_dict
         scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
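To show what these kwargs amount to, here is a self-contained sketch of a warmup-stable-decay learning-rate multiplier built on torch.optim.lr_scheduler.LambdaLR. It assumes a linear warmup and a linear decay to zero; the scheduler that transformers actually constructs from num_stable_steps and num_decay_steps may use a different decay shape and a minimum learning-rate ratio, so treat this as an illustration only.

import torch
from torch.optim.lr_scheduler import LambdaLR

def wsd_lambda(num_warmup_steps: int, num_stable_steps: int, num_decay_steps: int):
    """Piecewise multiplier: linear warmup -> constant plateau -> linear decay to zero (sketch)."""
    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        if step < num_warmup_steps + num_stable_steps:
            return 1.0
        decay_step = step - num_warmup_steps - num_stable_steps
        return max(0.0, 1.0 - decay_step / max(1, num_decay_steps))
    return lr_lambda

# Hypothetical model and step counts matching the earlier arithmetic (100 / 300 / 600).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = LambdaLR(optimizer, lr_lambda=wsd_lambda(100, 300, 600))

Calling scheduler.step() once per optimizer step walks the multiplier through 100 warmup, 300 stable, and 600 decay steps in this example, which is the shape the patch's default 1/3 stable, 2/3 decay split produces.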