Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00).

Commit 7f3c31f6f4 — [trainer] support early stop (#7797)
Parent: 92101f34a1
@@ -435,6 +435,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
+    early_stopping_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
|
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional

 import torch
 import torch.distributed as dist
-from transformers import PreTrainedModel
+from transformers import EarlyStoppingCallback, PreTrainedModel

 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
@@ -61,6 +61,9 @@ def _training_function(config: dict[str, Any]) -> None:
     if finetuning_args.use_swanlab:
         callbacks.append(get_swanlab_callback(finetuning_args))

+    if finetuning_args.early_stopping_steps is not None:
+        callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
+
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

     if finetuning_args.stage == "pt":
|
Loading…
x
Reference in New Issue
Block a user