From d3490aceb72e110fc2ba3cd905fca999c6420841 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Fri, 24 May 2024 00:23:40 +0800
Subject: [PATCH] fix paligemma sft requires transformers>=4.41.1

Former-commit-id: de0e67aff13f191fd899ad717ec349a6bdb14f2a
---
 src/llamafactory/data/preprocess.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 7bf9d4bc..4123645f 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -74,19 +74,21 @@ def preprocess_supervised_dataset(
     if processor is not None:
         model_inputs["pixel_values"] = []
         preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
+        if hasattr(processor, "image_seq_length"):  # paligemma models
+            model_inputs["token_type_ids"] = []
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava case
+        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava models
             examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
 
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma case
+        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
             image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             input_ids += [image_token_id] * getattr(processor, "image_seq_length")
             labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
@@ -120,6 +122,10 @@ def preprocess_supervised_dataset(
         model_inputs["labels"].append(labels)
         if processor is not None:
             model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
+            if hasattr(processor, "image_seq_length"):  # paligemma models
+                token_type_ids = [0] * getattr(processor, "image_seq_length")
+                token_type_ids += [1] * (len(input_ids) - getattr(processor, "image_seq_length"))
+                model_inputs["token_type_ids"].append(token_type_ids)
 
     return model_inputs
 
@@ -183,13 +189,15 @@ def preprocess_unsupervised_dataset(
     if processor is not None:
         model_inputs["pixel_values"] = []
         preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
+        if hasattr(processor, "image_seq_length"):  # paligemma models
+            model_inputs["token_type_ids"] = []
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava case
+        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava models
             examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 
         if len(examples["response"][i]) == 1:
@@ -209,7 +217,7 @@ def preprocess_unsupervised_dataset(
         if template.efficient_eos:
             labels += [tokenizer.eos_token_id]
 
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma case
+        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
             image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids