remove visual_inputs, fix qlora

Former-commit-id: a025c3df61db154bef13033518903bbf846f4fc8
This commit is contained in:
hiyouga 2024-08-31 00:24:51 +08:00
parent 51a0016873
commit f31e7e0dfc
22 changed files with 112 additions and 106 deletions

View File

@ -1,3 +1,2 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava
visual_inputs: true

View File

@ -1,3 +1,2 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
visual_inputs: true

View File

@ -3,7 +3,6 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
visual_inputs: true
template: qwen2_vl
finetuning_type: lora

View File

@ -1,6 +1,5 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
visual_inputs: true
### method
stage: sft
@ -9,7 +8,7 @@ finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: mllm_demo
dataset: mllm_demo,identity
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000

View File

@ -1,6 +1,5 @@
### model
model_name_or_path: llava-hf/llava-1.5-7b-hf
visual_inputs: true
### method
stage: sft

View File

@ -1,6 +1,5 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
visual_inputs: true
### method
stage: sft
@ -9,7 +8,7 @@ finetuning_type: lora
lora_target: all
### dataset
dataset: mllm_demo
dataset: mllm_demo,identity
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000

View File

@ -86,7 +86,7 @@ class VllmEngine(BaseEngine):
"max_lora_rank": model_args.vllm_max_lora_rank,
}
if model_args.visual_inputs:
if getattr(config, "model_type", None) == "llava":
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2

View File

@ -16,15 +16,12 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union
import torch
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class ModelArguments:
r"""
@ -121,10 +118,6 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@ -225,19 +218,31 @@ class ModelArguments:
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

View File

@ -257,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
@ -388,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)

View File

@ -24,6 +24,7 @@ from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
if TYPE_CHECKING:
@ -37,7 +38,6 @@ logger = get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@ -46,13 +46,7 @@ def _setup_full_tuning(
return
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
@ -63,7 +57,6 @@ def _setup_full_tuning(
def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
@ -72,8 +65,8 @@ def _setup_freeze_tuning(
return
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
config = model.config
@ -130,10 +123,7 @@ def _setup_freeze_tuning(
trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
@ -211,8 +201,7 @@ def _setup_lora_tuning(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
if (
finetuning_args.use_dora
@ -303,9 +292,9 @@ def init_adapter(
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
_setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32

View File

@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
@ -145,12 +138,16 @@ def load_model(
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else:
model = load_class.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers.models
@ -28,7 +28,7 @@ from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for quantized VLMs.
"""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
if getattr(model, "quantization_method", None):
if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif getattr(model.config, "model_type", None) == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
forbidden_modules = set()
if getattr(config, "model_type", None) in ["llava", "paligemma"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
elif getattr(config, "model_type", None) == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")
if finetuning_args.train_mm_proj_only:
raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
return forbidden_modules
def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
if not finetuning_args.freeze_vision_tower:
return target_modules
if getattr(config, "model_type", None) in ["llava", "paligemma"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif getattr(config, "model_type", None) == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules

View File

@ -131,11 +131,9 @@ def patch_model(
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
prepare_model_for_training(model, model_args)
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)
if not model_args.use_unsloth:

View File

@ -61,7 +61,6 @@ def run_sft(
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
# Metric utils
metric_module = {}

View File

@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
if model_args.visual_inputs and processor is not None:
if processor is not None:
getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
getattr(processor, "image_processor").push_to_hub(

View File

@ -90,7 +90,6 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),

View File

@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str:
return model_name.split("-")[0]
def get_model_info(model_name: str) -> Tuple[str, str, bool]:
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
Returns:
model_path (str)
template (str)
visual (bool)
"""
return get_model_path(model_name), get_template(model_name), get_visual(model_name)
return get_model_path(model_name), get_template(model_name)
def get_template(model_name: str) -> str:

View File

@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import get_visual
from .chatbot import create_chat_box
@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
engine.manager.get_elem_by_id("top.model_name").change(
lambda model_name: gr.Column(visible=get_visual(model_name)),
[engine.manager.get_elem_by_id("top.model_name")],
[chat_elems["image_box"]],
)

View File

@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
visual_inputs = gr.Checkbox(scale=1)
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
)
model_name.input(save_config, inputs=[lang, model_name], queue=False)
@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)

View File

@ -183,20 +183,6 @@ LOCALES = {
"label": "부스터",
},
},
"visual_inputs": {
"en": {
"label": "Visual inputs",
},
"ru": {
"label": "визуальные входы",
},
"zh": {
"label": "图像输入",
},
"ko": {
"label": "시각적 입력",
},
},
"training_stage": {
"en": {
"label": "Stage",

View File

@ -75,5 +75,4 @@ class Manager:
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}

View File

@ -116,7 +116,6 @@ class Runner:
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
@ -252,7 +251,6 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
eval_dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),