From cea8cea9dd074a85f102fb7470c63f9ad43a4329 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Wed, 15 May 2024 14:13:26 +0800
Subject: [PATCH] Update trainer.py

Former-commit-id: aa4a8933dd520227401b7041dae40fc6fb2ddaa2
---
 src/llmtuner/train/sft/trainer.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 5f187375..35671e1b 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from transformers import ProcessorMixin, Seq2SeqTrainer
+from transformers import Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -13,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
 
     from ...hparams import FinetuningArguments
@@ -26,7 +27,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
     """
 
-    def __init__(self, finetuning_args: "FinetuningArguments", processor: "ProcessorMixin", **kwargs) -> None:
+    def __init__(
+        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+    ) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
         self.processor = processor
@@ -46,6 +49,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
+        super()._save(output_dir, state_dict)
+        if self.processor is not None:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
@@ -121,10 +130,3 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         for label, pred in zip(decoded_labels, decoded_preds):
             res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
         writer.write("\n".join(res))
-
-    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
-        super().save_model(output_dir, _internal_call)
-        if self.processor is not None:
-            if output_dir is None:
-                output_dir = self.args.output_dir
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
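
Usage note (editorial sketch, not part of the commit): this patch makes the
multimodal processor optional and moves the extra save step from save_model()
into _save(), so the image processor is written on the same path the Trainer
itself uses to write files (final saves and checkpoints alike), presumably
also avoiding duplicate writes from non-main ranks in distributed runs. The
self-contained Python sketch below imitates that pattern with stand-in
classes; _StubTrainer, SketchTrainer, and all variable names are illustrative
assumptions, not transformers or LLaMA-Factory APIs.

import os
from types import SimpleNamespace
from typing import Any, Dict, Optional


class _StubTrainer:
    """Stand-in for Seq2SeqTrainer: _save() is the low-level file writer."""

    def __init__(self, output_dir: str) -> None:
        self.args = SimpleNamespace(output_dir=output_dir)  # mimics TrainingArguments

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, Any]] = None) -> None:
        target = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(target, exist_ok=True)
        print(f"weights -> {target}")

    def save_model(self, output_dir: Optional[str] = None) -> None:
        # The real Trainer routes both final saves and checkpoint saves through _save().
        self._save(output_dir)


class SketchTrainer(_StubTrainer):
    """Imitates the patched CustomSeq2SeqTrainer: optional processor, saved inside _save()."""

    def __init__(self, output_dir: str, processor: Optional[Any] = None) -> None:
        super().__init__(output_dir)
        self.processor = processor  # None for text-only fine-tuning

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, Any]] = None) -> None:
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            target = output_dir if output_dir is not None else self.args.output_dir
            # Real code: getattr(self.processor, "image_processor").save_pretrained(target)
            print(f"image processor -> {target}")


SketchTrainer("out/text").save_model()                    # weights only
SketchTrainer("out/mm", processor=object()).save_model()  # weights + image processor

With the optional parameter, text-only fine-tuning passes processor=None and
skips the extra save entirely, while multimodal runs persist the image
processor alongside every checkpoint.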