diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 2ea2fa1d..caf4a9b8 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
+    from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ..hparams import DataArguments
     from .template import Template
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
 
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
+    do_generate: bool = False,
 ) -> Tuple[Callable, Callable]:
     if stage == "pt":
         preprocess_func = partial(
@@ -48,7 +48,7 @@ def get_preprocess_and_print_func(
             data_args=data_args,
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
-    elif stage == "sft" and not training_args.predict_with_generate:
+    elif stage == "sft" and not do_generate:
         if data_args.packing:
             if data_args.neat_packing:
                 from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
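
Caller-side note: the function no longer reads `training_args.predict_with_generate` itself, so any call site must now pass the flag explicitly via `do_generate`. Below is a minimal sketch of an adapted call, assuming the surrounding objects (`data_args`, `template`, `tokenizer`, `processor`, `training_args`) are in scope as before; the actual call site is not part of this hunk.

```python
# Sketch of a caller adapted to the new signature. Only `do_generate` is new;
# everything else matches the parameters shown in the diff above.
preprocess_func, print_function = get_preprocess_and_print_func(
    data_args=data_args,
    stage="sft",
    template=template,
    tokenizer=tokenizer,
    processor=processor,
    # Previously read inside the function as training_args.predict_with_generate;
    # passing it explicitly decouples preprocessing from Seq2SeqTrainingArguments.
    do_generate=training_args.predict_with_generate,
)
```

This keeps `preprocess.py` free of any dependency on `Seq2SeqTrainingArguments` (note the trimmed TYPE_CHECKING import), while preserving the old behavior when callers forward `predict_with_generate`; the `False` default matches the previous non-generation path.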