From 6e98872622c90220fb461b9733f5484144401328 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Mon, 2 Sep 2024 23:56:21 +0800
Subject: [PATCH] fix #5324

Former-commit-id: a61c8c4890962f3847b19eff31b170cd7f54316c
---
 .../model/model_utils/liger_kernel.py        | 15 +++----
 src/llamafactory/model/model_utils/misc.py   | 12 +++---
 src/llamafactory/model/model_utils/moe.py    | 20 +++++-----
 src/llamafactory/model/model_utils/visual.py | 40 +++++++++++--------
 src/llamafactory/train/sft/metric.py         |  8 ++++
 5 files changed, 57 insertions(+), 38 deletions(-)

diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index 31edd97c..81c1132d 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -30,19 +30,20 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
     if not is_trainable or not model_args.enable_liger_kernel:
         return
 
-    if getattr(config, "model_type", None) == "gemma":
+    model_type = getattr(config, "model_type", None)
+    if model_type == "gemma":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "gemma2":
+    elif model_type == "gemma2":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "llama":
+    elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "mistral":
+    elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "mixtral":
+    elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "phi3":
+    elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "qwen2":
+    elif model_type == "qwen2":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
     else:
         logger.warning("Current model does not support liger kernel.")
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index d49222a3..342a1008 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -28,19 +28,19 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     r"""
     Finds all available modules to apply lora or galore.
     """
+    model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
-
-    if model.config.model_type == "chatglm":
+    if model_type == "chatglm":
         forbidden_modules.add("output_layer")
-    elif model.config.model_type == "internlm2":
+    elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model.config.model_type in ["llava", "paligemma"]:
+    elif model_type in ["llava", "paligemma"]:
         forbidden_modules.add("multi_modal_projector")
-    elif model.config.model_type == "qwen2_vl":
+    elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
 
     if freeze_vision_tower:
-        if model.config.model_type == "qwen2_vl":
+        if model_type == "qwen2_vl":
             forbidden_modules.add("visual")
         else:
             forbidden_modules.add("vision_tower")
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 5c7473aa..642d164a 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return
 
-    if getattr(model.config, "model_type", None) == "dbrx":
+    model_type = getattr(model.config, "model_type", None)
+    if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
 
         _set_z3_leaf_modules(model, [DbrxFFN])
 
-    if getattr(model.config, "model_type", None) == "jamba":
+    if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
         _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
 
-    if getattr(model.config, "model_type", None) == "jetmoe":
+    if model_type == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
 
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
-    if getattr(model.config, "model_type", None) == "mixtral":
+    if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
-    if getattr(model.config, "model_type", None) == "qwen2moe":
+    if model_type == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
-        elif getattr(config, "model_type", None) == "deepseek":
+        elif model_type == "deepseek":
             setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
 
-        elif getattr(config, "model_type", None) == "jetmoe":
+        elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
 
-    if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
+    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
         setattr(config, "output_router_logits", is_trainable)
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 1fbf3400..23f880a6 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -91,9 +91,10 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
         return output.to(model_args.compute_dtype)
 
     if getattr(model, "quantization_method", None):
-        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
+        model_type = getattr(model.config, "model_type", None)
+        if model_type in ["llava", "paligemma"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
-        elif getattr(model.config, "model_type", None) == "qwen2_vl":
+        elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
         else:
             return
@@ -106,7 +107,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     r"""
     Patches VLMs before loading them.
     """
-    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
+    model_type = getattr(config, "model_type", None)
+    if model_type == "llava":  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
 
     if getattr(config, "is_yi_vl_derived_model", None):
@@ -118,15 +120,16 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     r"""
     Freezes vision tower and language model for VLM full/freeze tuning.
    """
+    model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+    if model_type in ["llava", "paligemma"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
 
         if finetuning_args.train_mm_proj_only:
             forbidden_modules.add("language_model")
 
-    elif getattr(config, "model_type", None) == "qwen2_vl":
+    elif model_type == "qwen2_vl":
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("visual")
 
@@ -140,13 +143,14 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
     r"""
     Computes the number of special tokens per image.
     """
-    if getattr(config, "model_type", None) == "llava":
+    model_type = getattr(config, "model_type", None)
+    if model_type == "llava":
         image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
         if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
             image_seqlen += 1
-    elif getattr(config, "model_type", None) == "paligemma":
+    elif model_type == "paligemma":
         image_seqlen = config.vision_config.num_image_tokens
-    elif getattr(config, "model_type", None) == "qwen2_vl":  # variable length
+    elif model_type == "qwen2_vl":  # variable length
         image_seqlen = -1
 
     return image_seqlen
@@ -158,12 +162,16 @@ def patch_target_modules(
     r"""
     Freezes vision tower for VLM LoRA tuning.
     """
-    if not finetuning_args.freeze_vision_tower:
-        return target_modules
-
-    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
-        return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
-    elif getattr(config, "model_type", None) == "qwen2_vl":
-        return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+    model_type = getattr(config, "model_type", None)
+    if finetuning_args.freeze_vision_tower:
+        if model_type in ["llava", "paligemma"]:
+            return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "qwen2_vl":
+            return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
     else:
-        return target_modules
+        if model_type == "qwen2_vl":
+            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index 69327379..93610290 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -45,6 +45,9 @@ if is_rouge_available():
 
 
 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Computes the token with the largest likelihood to reduce memory footprint.
+    """
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]
@@ -59,6 +62,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
 
 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes accuracy and supports `batch_eval_metrics`.
+    """
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
@@ -84,6 +90,8 @@ class ComputeAccuracy:
 @dataclass
 class ComputeSimilarity:
     r"""
+    Computes text similarity scores and supports `batch_eval_metrics`.
+
     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """
 