From 1f093334d137058807a8bbaba7ef14dc5332933e Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Tue, 21 May 2024 08:57:14 +0800
Subject: [PATCH] support pretraining of llava

Former-commit-id: 6a4c8cf0a6a1674c693b9337f018ff8df7477f8f
---
 src/llamafactory/hparams/model_args.py | 4 ++++
 src/llamafactory/model/loader.py       | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 5885bb09..255051dc 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -85,6 +85,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
     )
+    tune_mm_proj: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to fine-tune only the mm_projector for MLLM."},
+    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 49b347d5..d9784593 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -163,6 +163,11 @@ def load_model(
     else:
         model.train()

+    if model_args.visual_inputs and model_args.tune_mm_proj:
+        lm_params = [param for name, param in model.named_parameters() if "language_model" in name]
+        for param in lm_params:
+            param.requires_grad_(False)
+
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
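
As context, here is a minimal sketch (not part of the patch) of what the new
tune_mm_proj flag does at load time. It assumes the Transformers LLaVA
implementation, where language-model weights are named "language_model.*";
the model class and checkpoint id below are illustrative assumptions.

    # sketch: illustrative only, mirrors the loader.py hunk above
    from transformers import LlavaForConditionalGeneration

    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # Freeze every parameter that belongs to the language model backbone,
    # as load_model() does when visual_inputs and tune_mm_proj are both set.
    for name, param in model.named_parameters():
        if "language_model" in name:
            param.requires_grad_(False)

    # Inspect what is left trainable: the multimodal projector (and, since
    # the hunk matches only "language_model", the vision tower as well,
    # unless it is frozen elsewhere).
    trainable = sorted({name.split(".")[0] for name, param in model.named_parameters() if param.requires_grad})
    print(trainable)  # e.g. ['multi_modal_projector', 'vision_tower']

Since ModelArguments is parsed with HfArgumentParser, the new field would be
enabled alongside visual inputs, e.g. --visual_inputs true --tune_mm_proj true.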