diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 6d5afdb5f..a909032c9 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -65,6 +65,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2_vl",
     "qwen2_5_vl",
     "qwen3_vl",
+    "qwen3_vl_moe",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py
index affa2efc8..819b840ef 100644
--- a/src/llamafactory/train/mca/workflow.py
+++ b/src/llamafactory/train/mca/workflow.py
@@ -82,9 +82,34 @@ def _check_model_support(model_args: "ModelArguments"):
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     if config.model_type not in MCA_SUPPORTED_MODELS:
-        raise ValueError(f"Model {config.model_type} is not supported by MCA.")
+        raise ValueError(
+            f"Model {config.model_type} is not supported by mcore_adapter. "
+            "You can try to upgrade mcore_adapter to the latest version for more supported models."
+        )
 
 
+def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
+    """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
+    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
+        return
+
+    params_to_freeze = []
+    if finetuning_args.freeze_vision_tower:
+        params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
+            params_to_freeze.extend(["vision_model.pos_embed"])
+
+    if finetuning_args.freeze_multi_modal_projector:
+        params_to_freeze.extend(["multi_modal_projector"])
+
+    if finetuning_args.freeze_language_model:
+        params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+    if params_to_freeze:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in params_to_freeze):
+                p.requires_grad_(False)
+
+
 def run_pt(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -161,22 +186,8 @@ def run_sft(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
-    # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
-        params_to_freeze = []
-        if finetuning_args.freeze_vision_tower:
-            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
-
-        if finetuning_args.freeze_multi_modal_projector:
-            params_to_freeze.extend(["multi_modal_projector"])
-
-        if finetuning_args.freeze_language_model:
-            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
-
-        if params_to_freeze:
-            for name, p in model.named_parameters():
-                if any(name.startswith(k) for k in params_to_freeze):
-                    p.requires_grad_(False)
+    # optional freezing for qwen_vl series
+    _freeze_model_parameters(model, finetuning_args)
 
     pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
@@ -229,6 +240,8 @@ def run_dpo(
     _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
+    _freeze_model_parameters(model, finetuning_args)
+
     if finetuning_args.use_ref_model:
         ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
         ref_model = AutoModel.from_config(ref_config)
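
Note on the new helper: _freeze_model_parameters freezes purely by parameter-name prefix over model.named_parameters(), so it applies to any module whose qualified name starts with one of the listed prefixes. Below is a minimal, self-contained sketch of that matching rule; TinyVLM and freeze_by_prefix are illustrative stand-ins, not LLaMA-Factory or mcore_adapter APIs, and the toy module layout only mimics the prefixes used in the diff.

# Sketch of the prefix-based freezing used by _freeze_model_parameters.
# TinyVLM / freeze_by_prefix are illustrative names; the real mcore_adapter
# model layout differs.
import torch.nn as nn


class TinyVLM(nn.Module):
    """Toy model whose submodule names mimic the prefixes listed in the PR."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Module()
        self.vision_model.patch_embed = nn.Linear(4, 8)
        self.vision_model.blocks = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
        self.multi_modal_projector = nn.Linear(8, 16)
        self.decoder = nn.Linear(16, 16)
        self.output_layer = nn.Linear(16, 32)


def freeze_by_prefix(model: nn.Module, prefixes: list) -> None:
    # Same rule as the PR: freeze every parameter whose fully qualified
    # name starts with one of the given prefixes.
    for name, p in model.named_parameters():
        if any(name.startswith(k) for k in prefixes):
            p.requires_grad_(False)


if __name__ == "__main__":
    model = TinyVLM()
    # Roughly what freeze_vision_tower=True plus freeze_multi_modal_projector=True selects.
    freeze_by_prefix(model, ["vision_model.blocks", "vision_model.patch_embed", "multi_modal_projector"])
    print([n for n, p in model.named_parameters() if p.requires_grad])
    # -> only decoder.* and output_layer.* parameters remain trainable

Because the matching is prefix-based rather than tied to a fixed module list, the same helper presumably covers the MoE variant (qwen3_vl_moe) without any per-architecture special casing beyond the pos_embed prefix added for the qwen3_vl family.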