diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 036e1a79..1c422ebf 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -37,10 +37,6 @@ if TYPE_CHECKING:
     from .template import Template
 
 
-def pad(seq, padding_value=0):
-    return pad_sequence(seq, batch_first=True, padding_value=padding_value)
-
-
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -159,14 +155,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
         if "image_bound" in features:  # for minicpmv inputs
             features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-            features["input_ids"] = pad(
-                features["input_ids"],
-            )
-            features["position_ids"] = pad(features["position_ids"])
-            features["labels"] = pad(features["labels"], padding_value=-100)
-            features["attention_mask"] = pad(
-                features["attention_mask"],
-            )
+            features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
+            features["labels"] = pad_sequence(features["labels"], batch_first=True, padding_value=-100)
+            features["attention_mask"] = pad_sequence(features["attention_mask"], batch_first=True, padding_value=0)
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 2e4831a1..8adcb884 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -171,13 +171,6 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments"):
         logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
         forbidden_modules.update(language_model_keys)
 
-    elif model_type == "minicpmv":
-        if finetuning_args.freeze_vision_tower:
-            forbidden_modules.add("vpm")
-            forbidden_modules.add("apm")
-            forbidden_modules.add("resampler")
-            forbidden_modules.add("tts")
-
     return forbidden_modules
@@ -257,6 +250,12 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="minicpmv",
+    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+)
+
+
 _register_composite_model(
     model_type="paligemma",
 )