From 30a3c6e886e07591ae83b07c568d4ae17a701a6e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:55:36 +0800 Subject: [PATCH] Update preprocess.py Former-commit-id: df52fb05b1b08887288bbaab7c612b7ac27c2290 --- src/llamafactory/data/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 2ea2fa1d..caf4a9b8 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments + from transformers import PreTrainedTokenizer, ProcessorMixin from ..hparams import DataArguments from .template import Template @@ -35,11 +35,11 @@ if TYPE_CHECKING: def get_preprocess_and_print_func( data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], + do_generate: bool = False, ) -> Tuple[Callable, Callable]: if stage == "pt": preprocess_func = partial( @@ -48,7 +48,7 @@ def get_preprocess_and_print_func( data_args=data_args, ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) - elif stage == "sft" and not training_args.predict_with_generate: + elif stage == "sft" and not do_generate: if data_args.packing: if data_args.neat_packing: from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence