Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)

[trainer] support early stop (#7797)

commit 7f3c31f6f4 (parent 92101f34a1)
@@ -435,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
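The new `early_stopping_steps` option follows the same `field(default=..., metadata=...)` pattern as the surrounding arguments, so it can be supplied from the command line or a config file like any other finetuning argument. A minimal, illustrative sketch of how such a dataclass field is parsed with Hugging Face's `HfArgumentParser` (not LLaMA-Factory's actual parsing code; the standalone `DemoArguments` class is hypothetical):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    # Mirrors the field added in the hunk above; the help string shows up in `--help`.
    early_stopping_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
    )


parser = HfArgumentParser(DemoArguments)
# Equivalent to passing `--early_stopping_steps 3` on the command line.
(demo_args,) = parser.parse_dict({"early_stopping_steps": 3})
print(demo_args.early_stopping_steps)  # -> 3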
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional

 import torch
 import torch.distributed as dist
-from transformers import PreTrainedModel
+from transformers import EarlyStoppingCallback, PreTrainedModel

 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
@@ -61,6 +61,9 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.use_swanlab:
         callbacks.append(get_swanlab_callback(finetuning_args))

+    if finetuning_args.early_stopping_steps is not None:
+        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
+
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

     if finetuning_args.stage == "pt":
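When `early_stopping_steps` is set, the change registers transformers' built-in `EarlyStoppingCallback` with that value as its patience. In transformers, this callback checks the metric named by `metric_for_best_model` at each evaluation and requires `load_best_model_at_end=True` together with a matching evaluation/save strategy; training stops once the metric fails to improve for `early_stopping_patience` consecutive evaluations. An illustrative sketch of that patience rule (not transformers' actual implementation), assuming a lower-is-better metric such as `eval_loss`:

def should_stop(metric_history: list[float], patience: int) -> bool:
    """Return True once `patience` consecutive evaluations show no improvement."""
    best = float("inf")
    stale = 0
    for value in metric_history:
        if value < best:  # improvement resets the counter
            best = value
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return True
    return False


# No improvement after 0.8, so with patience=3 training stops at the fifth evaluation.
print(should_stop([1.0, 0.8, 0.81, 0.82, 0.83], patience=3))  # -> True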