diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py
index 7bf9d4bc..4123645f 100644
--- a/src/llamafactory/data/preprocess.py
+++ b/src/llamafactory/data/preprocess.py
@@ -74,19 +74,21 @@ def preprocess_supervised_dataset(
     if processor is not None:
         model_inputs["pixel_values"] = []
         preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
+        if hasattr(processor, "image_seq_length"):  # paligemma models
+            model_inputs["token_type_ids"] = []
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava case
+        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava models
             examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 
         messages = examples["prompt"][i] + examples["response"][i]
         input_ids, labels = [], []
 
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma case
+        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
             image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             input_ids += [image_token_id] * getattr(processor, "image_seq_length")
             labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
@@ -120,6 +122,10 @@ def preprocess_supervised_dataset(
         model_inputs["labels"].append(labels)
         if processor is not None:
             model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
+            if hasattr(processor, "image_seq_length"):  # paligemma models
+                token_type_ids = [0] * getattr(processor, "image_seq_length")
+                token_type_ids += [1] * (len(input_ids) - getattr(processor, "image_seq_length"))
+                model_inputs["token_type_ids"].append(token_type_ids)
 
     return model_inputs
 
@@ -183,13 +189,15 @@ def preprocess_unsupervised_dataset(
     if processor is not None:
         model_inputs["pixel_values"] = []
         preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
+        if hasattr(processor, "image_seq_length"):  # paligemma models
+            model_inputs["token_type_ids"] = []
 
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava case
+        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava models
             examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 
         if len(examples["response"][i]) == 1:
@@ -209,7 +217,7 @@ def preprocess_unsupervised_dataset(
         if template.efficient_eos:
             labels += [tokenizer.eos_token_id]
 
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma case
+        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
             image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
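
Note (illustrative, not part of the patch): the sketch below shows the token layout these changes produce for a PaliGemma-style processor, i.e. a prefix of image placeholder tokens followed by text, with token_type_ids marking image positions as 0 and text positions as 1. The concrete values (IMAGE_SEQ_LENGTH = 4, the toy token ids, and the prompt being masked from the loss) are assumptions chosen to keep the example small; a real processor supplies image_seq_length itself.

    # Minimal, self-contained sketch of the sequence this diff builds.
    IGNORE_INDEX = -100      # loss-masking value, as in llamafactory
    IMAGE_SEQ_LENGTH = 4     # placeholder; real PaliGemma processors use e.g. 256
    image_token_id = 99      # placeholder id for the <image> token

    # Already-tokenized text portion of one example (toy ids).
    prompt_ids = [11, 12, 13]
    response_ids = [21, 22]

    # Prepend one image token per image slot and mask those positions,
    # mirroring the "paligemma models" branch in preprocess_supervised_dataset.
    input_ids = [image_token_id] * IMAGE_SEQ_LENGTH + prompt_ids + response_ids
    labels = [IGNORE_INDEX] * (IMAGE_SEQ_LENGTH + len(prompt_ids)) + response_ids

    # token_type_ids: 0 for the image prefix, 1 for every text position.
    token_type_ids = [0] * IMAGE_SEQ_LENGTH + [1] * (len(input_ids) - IMAGE_SEQ_LENGTH)

    assert len(input_ids) == len(labels) == len(token_type_ids)
    print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1, 1, 1]

Because token_type_ids is derived from len(input_ids) after the full sequence is assembled, it stays aligned with input_ids regardless of prompt or response length.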