diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index be566a5b..0b467724 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -8,15 +8,26 @@ from .utils import Role if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments - from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer + from PIL import Image + from transformers import ProcessorMixin, Seq2SeqTrainingArguments + from transformers.image_processing_utils import BaseImageProcessor + from transformers.tokenization_utils import PreTrainedTokenizer from ..hparams import DataArguments from .template import Template + logger = get_logger(__name__) +def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None: + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0] + if "pixel_values" not in model_inputs: + model_inputs["pixel_values"] = [] + model_inputs["pixel_values"].append(pixel_values) + + def preprocess_pretrain_dataset( examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments" ) -> Dict[str, List[List[int]]]: @@ -47,10 +58,10 @@ def preprocess_pretrain_dataset( def preprocess_supervised_dataset( examples: Dict[str, List[Any]], - tokenizer: "PreTrainedTokenizer", template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], data_args: "DataArguments", - processor: "AutoProcessor" = None, ) -> Dict[str, List[List[int]]]: # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. @@ -90,17 +101,15 @@ def preprocess_supervised_dataset( model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) if processor is not None and "images" in examples: - pixel_values = processor.image_processor(examples["images"][0], return_tensors="pt")["pixel_values"][0] - if "pixel_values" not in model_inputs: - model_inputs["pixel_values"] = [] - model_inputs["pixel_values"].append(pixel_values) + _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0]) + return model_inputs def preprocess_packed_supervised_dataset( examples: Dict[str, List[Any]], - tokenizer: "PreTrainedTokenizer", template: "Template", + tokenizer: "PreTrainedTokenizer", data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build inputs with format ` X1 Y1 X2 Y2 ` @@ -145,8 +154,9 @@ def preprocess_packed_supervised_dataset( def preprocess_unsupervised_dataset( examples: Dict[str, List[Any]], - tokenizer: "PreTrainedTokenizer", template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build inputs with format ` X` and labels with format `Y ` @@ -176,14 +186,17 @@ def preprocess_unsupervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) + if processor is not None and "images" in examples: + _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0]) return model_inputs def preprocess_pairwise_dataset( examples: Dict[str, List[Any]], - tokenizer: "PreTrainedTokenizer", template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build input pairs with format ` X`, `Y1 ` and `Y2 ` @@ -218,6 +231,8 @@ def preprocess_pairwise_dataset( model_inputs["prompt_ids"].append(prompt_ids) model_inputs["chosen_ids"].append(chosen_ids) model_inputs["rejected_ids"].append(rejected_ids) + if processor is not None and "images" in examples: + _preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0]) return model_inputs @@ -248,12 +263,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: def get_preprocess_and_print_func( - tokenizer: "PreTrainedTokenizer", - template: "Template", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"], - processor: Optional["AutoProcessor"] = None, + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], ) -> Tuple[Callable, Callable]: if stage == "pt": preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args) @@ -261,25 +276,38 @@ def get_preprocess_and_print_func( elif stage == "sft" and not training_args.predict_with_generate: if data_args.packing: preprocess_func = partial( - preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_packed_supervised_dataset, + template=template, + tokenizer=tokenizer, + data_args=data_args, ) else: preprocess_func = partial( preprocess_supervised_dataset, - tokenizer=tokenizer, template=template, - data_args=data_args, + tokenizer=tokenizer, processor=processor, + data_args=data_args, ) + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) elif stage == "rm": preprocess_func = partial( - preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_pairwise_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, ) print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) else: preprocess_func = partial( - preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args + preprocess_unsupervised_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) + return preprocess_func, print_function