From 53a89f53aa312020c9228d20f5707a3d98e8e985 Mon Sep 17 00:00:00 2001
From: niuba
Date: Wed, 9 Aug 2023 16:39:27 +0800
Subject: [PATCH] add last_checkpoint support

Former-commit-id: 2ec68d3398d86773c9076aae6b4e868ced0513d3
---
 src/llmtuner/tuner/sft/workflow.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 69d200f3..693fbd52 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -1,4 +1,5 @@
 # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
+import os
 from typing import TYPE_CHECKING, Optional, List
 
 from transformers import DataCollatorForSeq2Seq
@@ -10,11 +11,14 @@ from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
 from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+from transformers.trainer_utils import get_last_checkpoint
+from llmtuner.extras.logging import reset_logging, get_logger
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 
+logger = get_logger(__name__)
 
 def run_sft(
     model_args: "ModelArguments",
@@ -57,10 +61,28 @@ def run_sft(
         "temperature": 0.95,
         "logits_processor": get_logits_processor()
     }
-
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
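
Not part of the patch: a minimal, self-contained sketch of the resume logic this
change introduces, using only the public `transformers` helper the patch imports
(`get_last_checkpoint`). The `output_dir` path and `resume_from` value below are
placeholders for illustration, not values taken from the patch.

    import os
    from transformers.trainer_utils import get_last_checkpoint

    # Trainer writes periodic snapshots as checkpoint-<step> subfolders inside
    # output_dir; get_last_checkpoint() returns the highest-step one, or None.
    output_dir = "path/to/output_dir"  # placeholder path
    last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None

    # Same precedence as the patched workflow: an explicit
    # --resume_from_checkpoint wins, otherwise fall back to auto-detection.
    resume_from = None  # stands in for training_args.resume_from_checkpoint
    checkpoint = resume_from if resume_from is not None else last_checkpoint
    print(f"would call trainer.train(resume_from_checkpoint={checkpoint!r})")

With this in place, restarting an interrupted run against the same non-empty
output_dir resumes from the newest checkpoint instead of failing or starting over,
unless --overwrite_output_dir is passed.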