From b9eeaa97068cb273c3ddf644db57366aabdd8823 Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Mon, 6 Jan 2025 19:32:39 +0800 Subject: [PATCH] add some Former-commit-id: 785cc70ff205f5962c3ca67f453589e4a471ba8c --- src/llamafactory/data/mm_plugin.py | 1 + src/llamafactory/model/model_utils/visual.py | 6 ++++++ src/llamafactory/model/patcher.py | 13 +++++++------ src/llamafactory/train/sft/trainer.py | 2 +- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 0fbb5b1d..013672c0 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -366,6 +366,7 @@ class CpmOPlugin(BasePlugin): position_ids_ = list(range(input_ids_.size(0))) # print(input_ids_.shape, len(position_ids_) position_ids.append(position_ids_) + #TODO add pad position_ids = torch.tensor(position_ids, dtype=torch.int64) mm_inputs.update({ "image_bound": image_bounds_list, diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 246b9028..d0649514 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -142,6 +142,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni forbidden_modules.update({"visual.patch_embed", "visual.blocks", "model", "lm_head"}) elif finetuning_args.freeze_vision_tower: forbidden_modules.add("visual") + + elif model_type == "minicpmv": + if finetuning_args.freeze_vision_tower: + forbidden_modules.add("vpm") return forbidden_modules @@ -196,6 +200,8 @@ def patch_target_modules( return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules)) elif model_type == "qwen2_vl": return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) + elif model_type == "minicpmv": + return "^(?!.*vpm).*(?:{}).*".format("|".join(target_modules)) else: return target_modules else: diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 2ce84e86..7fe8c023 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -138,12 +138,13 @@ def patch_model( add_valuehead: bool, ) -> None: gen_config = model.generation_config # check and fix generation config - if not gen_config.do_sample and ( - (gen_config.temperature is not None and gen_config.temperature != 1.0) - or (gen_config.top_p is not None and gen_config.top_p != 1.0) - or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) - ): - gen_config.do_sample = True + if gen_config is not None: + if not gen_config.do_sample and ( + (gen_config.temperature is not None and gen_config.temperature != 1.0) + or (gen_config.top_p is not None and gen_config.top_p != 1.0) + or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) + ): + gen_config.do_sample = True if "GenerationMixin" not in str(model.generate.__func__): model.generate = MethodType(PreTrainedModel.generate, model) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 28ec25eb..3497af8c 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -24,6 +24,7 @@ import numpy as np import torch from transformers import Seq2SeqTrainer from typing_extensions import override +import copy from ...extras import logging from ...extras.constants import IGNORE_INDEX @@ -122,7 +123,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): labels = inputs.pop("labels", None) else: labels = inputs.get("labels") - loss, generated_tokens, _ = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs )