diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 79e69e7b..6217015d 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -435,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 3adb382b..cd22ba83 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 import torch.distributed as dist
-from transformers import PreTrainedModel
+from transformers import EarlyStoppingCallback, PreTrainedModel
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
@@ -61,6 +61,9 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.use_swanlab:
         callbacks.append(get_swanlab_callback(finetuning_args))
 
+    if finetuning_args.early_stopping_steps is not None:
+        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
+
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
     if finetuning_args.stage == "pt":
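For context on what this patch wires up: `early_stopping_steps` is forwarded to `transformers.EarlyStoppingCallback` as `early_stopping_patience`, which counts consecutive evaluation rounds without improvement in `metric_for_best_model`, not optimizer steps. The callback also asserts at train start that `load_best_model_at_end=True`, that `metric_for_best_model` is set, and that the evaluation strategy is not `"no"`. Below is a minimal sketch of the equivalent plain-Transformers setup; the concrete values (output directory, eval interval, patience of 3) are illustrative assumptions, not part of the patch.

```python
# Sketch only: how the forwarded value behaves in plain Transformers.
# early_stopping_steps=3 on the LLaMA-Factory side maps to early_stopping_patience=3 here,
# i.e. stop after 3 consecutive evaluations with no improvement in eval_loss.
from transformers import EarlyStoppingCallback, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                   # illustrative path
    eval_strategy="steps",              # the callback requires an eval strategy other than "no"
    eval_steps=500,
    save_strategy="steps",              # keep save/eval strategies aligned for load_best_model_at_end
    save_steps=500,
    load_best_model_at_end=True,        # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",  # metric whose improvement resets the patience counter
    greater_is_better=False,            # lower eval_loss counts as an improvement
)

early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
# Both objects would then be handed to a Trainer via `args=` and `callbacks=[...]`.
```

On the LLaMA-Factory side, the same behavior would presumably be enabled by setting `early_stopping_steps: 3` in the training config alongside the usual evaluation options (`eval_steps`, `metric_for_best_model`, `load_best_model_at_end`).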