mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

remove visual_inputs, fix qlora

This commit drops the manual `visual_inputs` flag: multimodal support is now detected from the model config's `model_type`, and projector dtype casting is reworked so that QLoRA fine-tuning works for the supported VLMs (LLaVA, PaliGemma, Qwen2-VL).

Former-commit-id: a025c3df61db154bef13033518903bbf846f4fc8
commit f31e7e0dfc, parent 51a0016873
@@ -1,3 +1,2 @@
 model_name_or_path: llava-hf/llava-1.5-7b-hf
 template: llava
-visual_inputs: true

@@ -1,3 +1,2 @@
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 template: qwen2_vl
-visual_inputs: true

@@ -3,7 +3,6 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
-visual_inputs: true
 template: qwen2_vl
 finetuning_type: lora

@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true

 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: full
 deepspeed: examples/deepspeed/ds_z3_config.json

 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000

@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: llava-hf/llava-1.5-7b-hf
-visual_inputs: true

 ### method
 stage: sft

@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true

 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: lora
 lora_target: all

 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
@@ -86,7 +86,7 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

-        if model_args.visual_inputs:
+        if getattr(config, "model_type", None) == "llava":
            image_size = config.vision_config.image_size
            patch_size = config.vision_config.patch_size
            self.image_feature_size = (image_size // patch_size) ** 2
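For reference, the number of image placeholder tokens the engine reserves is just the squared patch-grid size. A minimal sketch of the same computation, using the values that llava-hf/llava-1.5-7b-hf ships in its vision config (336-pixel images, 14-pixel patches — treat these numbers as assumptions, not something this diff guarantees):

    def image_feature_size(image_size: int, patch_size: int) -> int:
        # Number of vision-tower patches per image, i.e. placeholder tokens reserved by the engine.
        return (image_size // patch_size) ** 2

    print(image_feature_size(336, 14))  # 576 for a CLIP ViT-L/14-336 style tower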
@@ -16,15 +16,12 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, Literal, Optional, Union

+import torch
 from typing_extensions import Self


-if TYPE_CHECKING:
-    import torch
-
-
 @dataclass
 class ModelArguments:
     r"""
@@ -121,10 +118,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    visual_inputs: bool = field(
-        default=False,
-        metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
-    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@@ -225,19 +218,31 @@ class ModelArguments:
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
+    )
+    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        init=False,
+        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
+    )
+    block_diag_attn: bool = field(
+        default=False,
+        init=False,
+        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
+    )

     def __post_init__(self):
-        self.compute_dtype: Optional["torch.dtype"] = None
-        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
-        self.model_max_length: Optional[int] = None
-        self.block_diag_attn: bool = False
-
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")

-        if self.visual_inputs and self.use_unsloth:
-            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
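The derived attributes that used to be set inside `__post_init__` are now declared as dataclass fields with `init=False`, so they travel with the dataclass (e.g. through `asdict`) while still being impossible to pass through `__init__` or the argument parser. A minimal, self-contained sketch of that pattern (the `Args` class and its fields are illustrative, not part of the repository):

    from dataclasses import asdict, dataclass, field
    from typing import Optional

    @dataclass
    class Args:
        cutoff_len: int = 1024  # normal, user-settable field
        model_max_length: Optional[int] = field(default=None, init=False)  # derived, excluded from __init__

        def __post_init__(self):
            self.model_max_length = self.cutoff_len  # derived after construction

    args = Args(cutoff_len=2048)
    print(asdict(args))  # {'cutoff_len': 2048, 'model_max_length': 2048}
    # Args(model_max_length=5) would raise TypeError because init=False removes it from __init__.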
@@ -257,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")

-    if model_args.visual_inputs and data_args.packing:
-        raise ValueError("Cannot use packing in MLLM fine-tuning.")
-
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

@@ -388,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")

-    if finetuning_args.stage == "rm" and model_args.visual_inputs:
-        raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .model_utils.visual import get_forbidden_modules, patch_target_modules


 if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)

 def _setup_full_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@ def _setup_full_tuning(
         return

     logger.info("Fine-tuning method: Full")
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
-    if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
-        forbidden_modules.add("language_model")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
             if cast_trainable_params_to_fp32:
@@ -63,7 +57,6 @@ def _setup_full_tuning(

 def _setup_freeze_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
         return

     logger.info("Fine-tuning method: Freeze")
-    if model_args.visual_inputs:
-        config = model.config.text_config
+    if hasattr(model.config, "text_config"):  # composite models
+        config = getattr(model.config, "text_config")
     else:
         config = model.config

@@ -130,10 +123,7 @@ def _setup_freeze_tuning(

             trainable_layers.append(module_name)

-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
             forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@ def _setup_lora_tuning(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

-        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
+        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)

         if (
             finetuning_args.use_dora
@@ -303,9 +292,9 @@ def init_adapter(
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":
-        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
-        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "lora":
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
@@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":

     patch_tokenizer(tokenizer)

-    if model_args.visual_inputs:
-        try:
-            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-            setattr(processor, "tokenizer", tokenizer)
-        except Exception:
-            raise ValueError(
-                "This multimodal LLM is not supported.\n"
-                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
-                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
-            )
-    else:
+    try:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    except Exception:
         processor = None

     return {"tokenizer": tokenizer, "processor": processor}
@@ -145,12 +138,16 @@ def load_model(

     if model_args.mixture_of_depths == "load":
         model = load_mod_pretrained_model(**init_kwargs)
-    elif model_args.visual_inputs:
-        model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-    elif model_args.train_from_scratch:
-        model = AutoModelForCausalLM.from_config(config)
     else:
-        model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+        if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+            load_class = AutoModelForVision2Seq
+        else:
+            load_class = AutoModelForCausalLM
+
+        if model_args.train_from_scratch:
+            model = load_class.from_config(config)
+        else:
+            model = load_class.from_pretrained(**init_kwargs)

     if model_args.mixture_of_depths == "convert":
         model = convert_pretrained_model_to_mod(model, config, model_args)
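With `visual_inputs` gone, the loader decides between `AutoModelForCausalLM` and `AutoModelForVision2Seq` by checking whether the checkpoint's config class is registered in the vision-to-sequence auto-mapping. A rough sketch of the same decision for an arbitrary checkpoint (requires `transformers` and network access; the model name is just an example):

    from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

    config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
    # _model_mapping maps config classes to model classes for each Auto* family.
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():
        load_class = AutoModelForVision2Seq  # multimodal checkpoint
    else:
        load_class = AutoModelForCausalLM  # plain language model
    print(load_class.__name__)  # AutoModelForVision2Seq for the LLaVA example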
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
 import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)
@@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
         self.act = ACT2FN[projector_hidden_act]


-def autocast_projector_dtype(
-    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
-) -> None:
+def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+    r"""
+    Casts projector output to half precision for quantized VLMs.
+    """
+
     def _mm_projector_forward_post_hook(
         module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)

-    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
+    if getattr(model, "quantization_method", None):
+        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
+            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
+        elif getattr(model.config, "model_type", None) == "qwen2_vl":
+            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        else:
+            return
+
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
-        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


 def configure_visual_model(config: "PretrainedConfig") -> None:
+    r"""
+    Patches VLMs before loading them.
+    """
     if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
+    r"""
+    Freezes vision tower and language model for VLM full/freeze tuning.
+    """
+    forbidden_modules = set()
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("visual")
+
+        if finetuning_args.train_mm_proj_only:
+            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+
+    return forbidden_modules
+
+
+def patch_target_modules(
+    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> Union[str, List[str]]:
+    r"""
+    Freezes vision tower for VLM LoRA tuning.
+    """
+    if not finetuning_args.freeze_vision_tower:
+        return target_modules
+
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+    else:
+        return target_modules
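The LoRA path now hands PEFT a single regex instead of a module list whenever the vision tower must stay frozen. A quick, self-contained check of what that negative lookahead matches (module names below are made up for illustration):

    import re

    pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(["q_proj", "v_proj"]))

    print(bool(re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True: gets a LoRA adapter
    print(bool(re.match(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj")))  # False: vision tower stays frozen
    print(bool(re.match(pattern, "multi_modal_projector.linear_1")))                  # False: no target keyword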
@@ -131,11 +131,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)

-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
-
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)

     if not model_args.use_unsloth:
@@ -61,7 +61,6 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns

     # Metric utils
     metric_module = {}
@@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     if model_args.export_hub_model_id is not None:
         tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)

-    if model_args.visual_inputs and processor is not None:
+    if processor is not None:
         getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
         if model_args.export_hub_model_id is not None:
             getattr(processor, "image_processor").push_to_hub(
@@ -90,7 +90,6 @@ class WebChatModel(ChatModel):
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
@@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str:
     return model_name.split("-")[0]


-def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+def get_model_info(model_name: str) -> Tuple[str, str]:
     r"""
     Gets the necessary information of this model.

     Returns:
         model_path (str)
         template (str)
-        visual (bool)
     """
-    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+    return get_model_path(model_name), get_template(model_name)


 def get_template(model_name: str) -> str:
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Dict

 from ...extras.packages import is_gradio_available
+from ..common import get_visual
 from .chatbot import create_chat_box


@@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
         lambda: ([], []), outputs=[chatbot, messages]
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])

-    engine.manager.get_elem_by_id("top.visual_inputs").change(
-        lambda enabled: gr.Column(visible=enabled),
-        [engine.manager.get_elem_by_id("top.visual_inputs")],
+    engine.manager.get_elem_by_id("top.model_name").change(
+        lambda model_name: gr.Column(visible=get_visual(model_name)),
+        [engine.manager.get_elem_by_id("top.model_name")],
         [chat_elems["image_box"]],
     )
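In the web UI the image box now toggles off the selected model name rather than a dedicated checkbox. A stripped-down sketch of the same Gradio wiring (a standalone demo with made-up model names and an is_visual stand-in for the repo's get_visual, not the repository's actual layout):

    import gradio as gr

    def is_visual(model_name: str) -> bool:
        # Stand-in for get_visual(): pretend anything with a known VLM prefix is multimodal.
        return model_name.startswith(("LLaVA", "Qwen2-VL", "PaliGemma"))

    with gr.Blocks() as demo:
        model_name = gr.Dropdown(choices=["Llama-3-8B", "LLaVA-1.5-7B", "Qwen2-VL-7B"], value="Llama-3-8B")
        with gr.Column(visible=False) as image_box:
            gr.Image(type="pil")

        # Re-evaluate the image box visibility whenever the selected model changes.
        model_name.change(lambda name: gr.Column(visible=is_visual(name)), [model_name], [image_box])

    if __name__ == "__main__":
        demo.launch()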
@@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
         template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
         booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
-        visual_inputs = gr.Checkbox(scale=1)

-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
@@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
         template=template,
         rope_scaling=rope_scaling,
         booster=booster,
-        visual_inputs=visual_inputs,
     )
@@ -183,20 +183,6 @@ LOCALES = {
             "label": "부스터",
         },
     },
-    "visual_inputs": {
-        "en": {
-            "label": "Visual inputs",
-        },
-        "ru": {
-            "label": "визуальные входы",
-        },
-        "zh": {
-            "label": "图像输入",
-        },
-        "ko": {
-            "label": "시각적 입력",
-        },
-    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -75,5 +75,4 @@ class Manager:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
-            self._id_to_elem["top.visual_inputs"],
         }
@@ -116,7 +116,6 @@ class Runner:
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
-            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -252,7 +251,6 @@ class Runner:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("eval.dataset_dir"),
             eval_dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),
|
Loading…
x
Reference in New Issue
Block a user