hiyouga 3e729798df refactor data preprocessing, fix mllm rlhf
Former-commit-id: 3a023bca2a502810a436cfba7708df164754ea62
2024-05-24 04:08:25 +08:00

85 lines
3.0 KiB
Python

from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
preprocess_packed_supervised_dataset,
preprocess_supervised_dataset,
print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
if TYPE_CHECKING:
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments
from .template import Template
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
) -> Tuple[Callable, Callable]:
    """Pick the preprocessing and example-printing callables for a training stage.

    Both returned callables are ``functools.partial`` objects with the relevant
    arguments already bound as keywords.

    Dispatch rules:
        * ``"pt"``: pretrain preprocessing, unsupervised example printer.
        * ``"sft"`` while ``training_args.predict_with_generate`` is falsy:
          packed supervised preprocessing when ``data_args.packing`` is truthy,
          otherwise plain supervised preprocessing; supervised example printer.
        * ``"rm"``: pairwise preprocessing and pairwise example printer.
        * ``"kto"``: feedback preprocessing, supervised example printer.
        * anything else (including ``"sft"`` with ``predict_with_generate``
          truthy): unsupervised preprocessing and unsupervised example printer.

    Args:
        data_args: Data arguments; bound into every preprocessing callable and
            consulted for ``packing`` in the supervised branch.
        training_args: Training arguments; only ``predict_with_generate`` is read,
            and only when ``stage`` is ``"sft"``.
        stage: Name of the training stage to dispatch on.
        template: Template bound into every preprocessing callable except the
            pretrain one.
        tokenizer: Tokenizer bound into every returned callable.
        processor: Optional processor; bound into the supervised (non-packed),
            pairwise, feedback and unsupervised preprocessing callables.

    Returns:
        A ``(preprocess_func, print_function)`` tuple.
    """
    if stage == "pt":
        # Only the tokenizer and data arguments are bound for this stage.
        return (
            partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args),
            partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
        )

    # Keyword set shared by every remaining branch except the packed one.
    shared_kwargs = {
        "template": template,
        "tokenizer": tokenizer,
        "processor": processor,
        "data_args": data_args,
    }

    if stage == "sft" and not training_args.predict_with_generate:
        if data_args.packing:
            # The packed variant is bound without a processor.
            supervised_preprocess = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )
        else:
            supervised_preprocess = partial(preprocess_supervised_dataset, **shared_kwargs)
        return (
            supervised_preprocess,
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    if stage == "rm":
        return (
            partial(preprocess_pairwise_dataset, **shared_kwargs),
            partial(print_pairwise_dataset_example, tokenizer=tokenizer),
        )

    if stage == "kto":
        # Feedback data is displayed with the supervised example printer.
        return (
            partial(preprocess_feedback_dataset, **shared_kwargs),
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    # Fallback: every other stage value, and "sft" when predicting with generate.
    return (
        partial(preprocess_unsupervised_dataset, **shared_kwargs),
        partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
    )