from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer

    from ..hparams import DataArguments
    from .template import Template


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    """Select the preprocessing and example-printing callables for a training stage.

    Both returned callables are ``functools.partial`` objects with the relevant
    arguments (tokenizer, template, processor, data_args) already bound.

    Dispatch rules:
        * ``"pt"``: pretrain preprocessor, unsupervised example printer.
        * ``"sft"`` while ``training_args.predict_with_generate`` is falsy:
          packed supervised preprocessor if ``data_args.packing`` is truthy,
          otherwise the plain supervised preprocessor; supervised printer.
        * ``"rm"``: pairwise preprocessor and pairwise printer.
        * ``"kto"``: feedback preprocessor, supervised printer.
        * anything else (including ``"sft"`` with ``predict_with_generate``
          set): unsupervised preprocessor and unsupervised printer.

    Args:
        data_args: Data arguments; bound into every preprocessor, and its
            ``packing`` attribute is consulted for the ``"sft"`` stage.
        training_args: Only ``predict_with_generate`` is read, and only when
            ``stage`` is ``"sft"``.
        stage: Identifier of the training stage.
        template: Bound into every preprocessor except the pretrain one.
        tokenizer: Bound into every preprocessor and every printer.
        processor: Optional processor; bound into every preprocessor except
            the pretrain and packed supervised ones.

    Returns:
        A ``(preprocess_func, print_function)`` tuple.
    """
    if stage == "pt":
        # The pretrain preprocessor is the only one that takes no template.
        pretrain_preprocess = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        return pretrain_preprocess, partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    if stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # The packed variant is not given the processor.
            supervised_preprocess = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )
        else:
            supervised_preprocess = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )
        return supervised_preprocess, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    if stage == "rm":
        pairwise_preprocess = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        return pairwise_preprocess, partial(print_pairwise_dataset_example, tokenizer=tokenizer)

    if stage == "kto":
        feedback_preprocess = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        # KTO examples are printed with the supervised printer.
        return feedback_preprocess, partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Fallback: every remaining case, notably "sft" with predict_with_generate.
    unsupervised_preprocess = partial(
        preprocess_unsupervised_dataset,
        template=template,
        tokenizer=tokenizer,
        processor=processor,
        data_args=data_args,
    )
    return unsupervised_preprocess, partial(print_unsupervised_dataset_example, tokenizer=tokenizer)