[trainer] support early stop (#7797)

hoshi-hiyouga 2025-04-22 01:59:33 +08:00 committed by GitHub
parent 92101f34a1
commit 7f3c31f6f4
2 changed files with 8 additions and 1 deletion


@@ -435,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},


@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import torch
 import torch.distributed as dist
-from transformers import PreTrainedModel
+from transformers import EarlyStoppingCallback, PreTrainedModel
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
@@ -61,6 +61,9 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.use_swanlab:
         callbacks.append(get_swanlab_callback(finetuning_args))
 
+    if finetuning_args.early_stopping_steps is not None:
+        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
+
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
     if finetuning_args.stage == "pt":
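Worth noting for reviewers: transformers' EarlyStoppingCallback measures patience in evaluation calls rather than optimizer steps, and it asserts that load_best_model_at_end=True, a metric_for_best_model, and an evaluation strategy are configured, so the effective behaviour also depends on the evaluation settings passed to the trainer. A minimal sketch of the equivalent wiring against a plain transformers Trainer, with placeholder values that are not the project's actual defaults:

# Minimal sketch with placeholder values; `eval_strategy` is named
# `evaluation_strategy` on older transformers releases.
from transformers import EarlyStoppingCallback, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                 # placeholder output path
    eval_strategy="steps",            # evaluation must be enabled for early stopping
    eval_steps=500,
    load_best_model_at_end=True,      # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Stop after 5 evaluations without improvement in `eval_loss`.
early_stop = EarlyStoppingCallback(early_stopping_patience=5)
# The callback is then passed to the trainer, e.g.
# Trainer(model=..., args=training_args, callbacks=[early_stop], ...).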