From 4487a4a790f29ca2697c6a0eae37a4b358bad755 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Fri, 26 Apr 2024 03:33:07 +0800
Subject: [PATCH] Update loader.py

Former-commit-id: 3408af236f0b4ef64c3bfa791ef757828a74da7f
---
 src/llmtuner/data/loader.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index fa4aa9c1..ca0d5407 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -16,12 +16,13 @@ from .utils import checksum, merge_dataset
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import AutoProcessor, Seq2SeqTrainingArguments
+    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
 
     from ..hparams import DataArguments, ModelArguments
     from .parser import DatasetAttr
 
+
 logger = get_logger(__name__)
 
 
@@ -114,12 +115,12 @@ def load_single_dataset(
 
 
 def get_dataset(
-    tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"],
-    processor: Optional["AutoProcessor"] = None,
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:
@@ -149,7 +150,7 @@ def get_dataset(
 
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            tokenizer, template, data_args, training_args, stage, processor
+            data_args, training_args, stage, template, tokenizer, processor
        )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
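
Note: the snippet below is an editor's sketch, not part of the patch. It
illustrates how a caller of get_dataset() would change under this commit:
tokenizer moves from the first positional argument to the slot after stage,
and the processor is typed as transformers.ProcessorMixin rather than
AutoProcessor. The variables model_args, data_args, training_args, tokenizer,
and processor are assumed to be constructed elsewhere in the usual llmtuner
setup.

    # Hypothetical call site; construction of the arguments is assumed.
    # Before this commit, tokenizer led the positional arguments:
    #   dataset = get_dataset(tokenizer, model_args, data_args, training_args, "sft", processor)
    # After this commit, tokenizer follows stage:
    dataset = get_dataset(model_args, data_args, training_args, "sft", tokenizer, processor)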