diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index b28fa1dc..ebb16edd 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -1,4 +1,5 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
+import os
 
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
@@ -10,11 +11,14 @@
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
 from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+from transformers.trainer_utils import get_last_checkpoint
+from llmtuner.extras.logging import reset_logging, get_logger
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
+logger = get_logger(__name__)
 def run_sft(
     model_args: "ModelArguments",
@@ -58,7 +62,12 @@ def run_sft(
 
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
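
Note: the hunks above use `last_checkpoint` and import `os` and `get_last_checkpoint`, but the hunk that actually detects the last checkpoint is not shown. The sketch below is a hypothetical reconstruction of that missing piece, modeled on the Hugging Face example script cited at the top of the file; the helper name `detect_last_checkpoint` and its exact checks are assumptions, not part of this diff.

```python
import os

from transformers.trainer_utils import get_last_checkpoint

from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)


def detect_last_checkpoint(training_args) -> "str | None":
    # Look for an existing checkpoint in output_dir so training can resume,
    # mirroring the logic in transformers' run_summarization.py example.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                "Output directory ({}) already exists and is not empty. "
                "Use --overwrite_output_dir to train from scratch.".format(training_args.output_dir)
            )
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info("Checkpoint detected, resuming training at {}.".format(last_checkpoint))
    return last_checkpoint
```

With a helper like this, `last_checkpoint = detect_last_checkpoint(training_args)` would run before the `if training_args.do_train:` block, and `trainer.train(resume_from_checkpoint=checkpoint)` then picks it up unless the user passes an explicit `--resume_from_checkpoint`.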