diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index 855c98ac..5a36a1e9 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -7,6 +7,7 @@ stage: sft
 do_train: true
 finetuning_type: full
 freeze_vision_tower: true  # choices: [true, false]
+freeze_multi_modal_projector: true  # choices: [true, false]
 train_mm_proj_only: false  # choices: [true, false]
 deepspeed: examples/deepspeed/ds_z3_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
 
@@ -29,7 +30,7 @@ overwrite_output_dir: true
 per_device_train_batch_size: 1
 gradient_accumulation_steps: 2
 learning_rate: 1.0e-5
-num_train_epochs: 30.0
+num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 bf16: true
diff --git a/setup.py b/setup.py
index 6a7c2791..10d93551 100644
--- a/setup.py
+++ b/setup.py
@@ -71,7 +71,7 @@ def main():
         name="llamafactory",
         version=get_version(),
         author="hiyouga",
-        author_email="hiyouga" "@" "buaa.edu.cn",
+        author_email="hiyouga AT buaa.edu.cn",
         description="Easy-to-use LLM fine-tuning framework",
         long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 490b571d..682ccb10 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -396,8 +396,7 @@ _register_template(
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
     default_system=(
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
     ),
     replace_jinja_template=True,
 )
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 29e91a27..5770e395 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -364,6 +364,10 @@ class FinetuningArguments(
         default=True,
         metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
     )
+    freeze_multi_modal_projector: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
+    )
     train_mm_proj_only: bool = field(
         default=False,
         metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
@@ -398,6 +402,7 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 022cce06..52815bb2 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 
 import torch
@@ -202,12 +203,8 @@ def load_model(
 
     logger.info_rank0(param_stats)
 
-    if model_args.print_param_status:
+    if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
         for name, param in model.named_parameters():
-            print(
-                "name: {}, dtype: {}, device: {}, trainable: {}".format(
-                    name, param.dtype, param.device, param.requires_grad
-                )
-            )
+            print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
 
     return model
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 5f4b747e..5c5178d4 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, List
 
 from ...extras import logging
+from .visual import COMPOSITE_MODELS
 
 
 if TYPE_CHECKING:
@@ -34,18 +35,12 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
-        forbidden_modules.add("multi_modal_projector")
-    elif model_type == "qwen2_vl":
-        forbidden_modules.add("merger")
 
-    if freeze_vision_tower:
-        if model_type == "mllama":
-            forbidden_modules.add("vision_model")
-        elif model_type == "qwen2_vl":
-            forbidden_modules.add("visual")
-        else:
-            forbidden_modules.add("vision_tower")
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+
+    if freeze_vision_tower and model_type in COMPOSITE_MODELS:
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
 
     module_names = set()
     for name, module in model.named_modules():
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 246b9028..066fa979 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -15,7 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 import transformers
@@ -35,6 +36,40 @@ logger = logging.get_logger(__name__)
 transformers_logger = transformers.utils.logging.get_logger(__name__)
 
 
+@dataclass
+class CompositeModel:
+    model_type: str
+    projector_key: str
+    vision_model_keys: List[str]
+    language_model_keys: List[str]
+
+    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
+        for key in self.projector_key.split("."):
+            module = getattr(module, key)
+
+        return module
+
+
+COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+
+
+def _register_composite_model(
+    model_type: str,
+    projector_key: Optional[str] = None,
+    vision_model_keys: Optional[List[str]] = None,
+    language_model_keys: Optional[List[str]] = None,
+):
+    projector_key = projector_key or "multi_modal_projector"
+    vision_model_keys = vision_model_keys or ["vision_tower"]
+    language_model_keys = language_model_keys or ["language_model"]
+    COMPOSITE_MODELS[model_type] = CompositeModel(
+        model_type=model_type,
+        projector_key=projector_key,
+        vision_model_keys=vision_model_keys,
+        language_model_keys=language_model_keys,
+    )
+
+
 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
     def __init__(self, config: "LlavaConfig") -> None:
         super().__init__()
@@ -92,10 +127,8 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
 
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
-            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
-        elif model_type == "qwen2_vl":
-            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        if model_type in COMPOSITE_MODELS:
+            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
         else:
             return
 
@@ -107,8 +140,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     r"""
     Patches VLMs before loading them.
""" - model_type = getattr(config, "model_type", None) - if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]: + if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None): # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) @@ -123,25 +155,21 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni """ model_type = getattr(config, "model_type", None) forbidden_modules = set() - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + if model_type in COMPOSITE_MODELS: if finetuning_args.freeze_vision_tower: - forbidden_modules.add("vision_tower") + vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys + logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.") + forbidden_modules.update(vision_model_keys) + + if finetuning_args.freeze_multi_modal_projector: + projector_key = COMPOSITE_MODELS[model_type].projector_key + logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.") + forbidden_modules.add(projector_key) if finetuning_args.train_mm_proj_only: - forbidden_modules.add("language_model") - - elif model_type == "mllama": - if finetuning_args.freeze_vision_tower: - forbidden_modules.add("vision_model") - - if finetuning_args.train_mm_proj_only: - forbidden_modules.add("language_model") - - elif model_type == "qwen2_vl": - if finetuning_args.train_mm_proj_only: - forbidden_modules.update({"visual.patch_embed", "visual.blocks", "model", "lm_head"}) - elif finetuning_args.freeze_vision_tower: - forbidden_modules.add("visual") + language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys + logger.info_rank0(f"Set language model not trainable: {language_model_keys}.") + forbidden_modules.update(language_model_keys) return forbidden_modules @@ -190,18 +218,57 @@ def patch_target_modules( model_type = getattr(config, "model_type", None) vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: - return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) - elif model_type == "mllama": - return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules)) - elif model_type == "qwen2_vl": - return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) + if model_type in COMPOSITE_MODELS: + vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys + logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.") + vision_model_keys = "|".join(vision_model_keys) + target_modules = "|".join(target_modules) + return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*" else: return target_modules else: - if model_type == "qwen2_vl": + if model_type == "qwen2_vl": # avoid attaching lora to Conv3D layer return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) elif vit_model_type == "pixtral": return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules)) else: return target_modules + + +_register_composite_model( + model_type="llava", +) + + +_register_composite_model( + model_type="llava_next", +) + + +_register_composite_model( + model_type="llava_next_video", +) + + +_register_composite_model( + model_type="paligemma", +) + + +_register_composite_model( + model_type="video_llava", +) + + +_register_composite_model( 
+ model_type="mllama", + vision_model_keys=["vision_model"], +) + + +_register_composite_model( + model_type="qwen2_vl", + projector_key="visual.merger", + vision_model_keys=["visual.patch_embed", "visual.blocks"], + language_model_keys=["model", "lm_head"], +) diff --git a/tests/data/test_template.py b/tests/data/test_template.py index cf260b47..3cf9227e 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -100,8 +100,7 @@ def test_encode_multiturn(use_fast: bool): ) answer_str_1 = "I am fine!<|eot_id|>" prompt_str_2 = ( - "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) answer_str_2 = "很高兴认识你!<|eot_id|>" _check_tokenization( diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py index 0b171508..cdf62807 100644 --- a/tests/model/model_utils/test_checkpointing.py +++ b/tests/model/model_utils/test_checkpointing.py @@ -14,6 +14,7 @@ import os +import pytest import torch from llamafactory.extras.misc import get_current_device @@ -39,16 +40,11 @@ TRAIN_ARGS = { } -def test_checkpointing_enable(): - model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS) +@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True]) +def test_vanilla_checkpointing(disable_gradient_checkpointing: bool): + model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS) for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): - assert getattr(module, "gradient_checkpointing") is True - - -def test_checkpointing_disable(): - model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS) - for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()): - assert getattr(module, "gradient_checkpointing") is False + assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing def test_unsloth_gradient_checkpointing(): diff --git a/tests/model/model_utils/test_visual.py b/tests/model/model_utils/test_visual.py new file mode 100644 index 00000000..b4e23def --- /dev/null +++ b/tests/model/model_utils/test_visual.py @@ -0,0 +1,77 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+import torch
+from transformers import AutoConfig, AutoModelForVision2Seq
+
+from llamafactory.hparams import FinetuningArguments, ModelArguments
+from llamafactory.model.adapter import init_adapter
+
+
+@pytest.mark.parametrize(
+    "freeze_vision_tower,freeze_multi_modal_projector,train_mm_proj_only",
+    [
+        (False, False, False),
+        (False, True, False),
+        (True, False, False),
+        (True, True, False),
+        (True, False, True),
+    ],
+)
+def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, train_mm_proj_only: bool):
+    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
+    finetuning_args = FinetuningArguments(
+        finetuning_type="full",
+        freeze_vision_tower=freeze_vision_tower,
+        freeze_multi_modal_projector=freeze_multi_modal_projector,
+        train_mm_proj_only=train_mm_proj_only,
+    )
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
+    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
+    for name, param in model.named_parameters():
+        if any(key in name for key in ["visual.patch_embed", "visual.blocks"]):
+            assert param.requires_grad != freeze_vision_tower
+        elif "visual.merger" in name:
+            assert param.requires_grad != freeze_multi_modal_projector
+        else:
+            assert param.requires_grad != train_mm_proj_only
+
+
+@pytest.mark.parametrize("freeze_vision_tower", [False, True])
+def test_visual_lora(freeze_vision_tower: bool):
+    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
+    finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
+    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
+    trainable_params, frozen_params = set(), set()
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            trainable_params.add(name)
+        else:
+            frozen_params.add(name)
+
+    if freeze_vision_tower:
+        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
+    else:
+        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params
+
+    assert "merger" not in trainable_params
+    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
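
Reviewer note (not part of the patch): a minimal sketch of how the new COMPOSITE_MODELS registry resolves a nested projector key such as Qwen2-VL's "visual.merger", and of the kind of LoRA target-module pattern patch_target_modules now emits when the vision tower is frozen. The toy module classes and the ["q_proj", "v_proj"] target list are hypothetical stand-ins chosen for illustration; only the dotted-key walk and the regex shape come from the patch above.

import re

import torch


class _ToyMerger(torch.nn.Module):  # hypothetical stand-in for the Qwen2-VL merger projector
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)


class _ToyVisual(torch.nn.Module):  # hypothetical stand-in for the vision tower wrapper
    def __init__(self) -> None:
        super().__init__()
        self.merger = _ToyMerger()


class _ToyQwen2VL(torch.nn.Module):  # hypothetical stand-in for the composite model
    def __init__(self) -> None:
        super().__init__()
        self.visual = _ToyVisual()


def resolve_projector(module: torch.nn.Module, projector_key: str) -> torch.nn.Module:
    # Mirrors CompositeModel.get_projector: walk the dotted key one attribute at a time.
    for key in projector_key.split("."):
        module = getattr(module, key)

    return module


model = _ToyQwen2VL()
assert resolve_projector(model, "visual.merger") is model.visual.merger

# With freeze_vision_tower=True and (assumed) target_modules=["q_proj", "v_proj"],
# patch_target_modules builds a negative-lookahead pattern of the form below, so
# language-model projections match while vision-tower parameter names do not.
pattern = "^(?!.*visual.patch_embed|visual.blocks).*(?:q_proj|v_proj).*"
assert re.match(pattern, "model.layers.0.self_attn.q_proj")
assert re.match(pattern, "visual.blocks.0.attn.qkv") is None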