mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
This commit is contained in:
		
							parent
							
								
									d789b667d7
								
							
						
					
					
						commit
						2f6fc27c8b
					
				@ -1,3 +1,2 @@
 | 
			
		||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
 | 
			
		||||
template: llava
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,2 @@
 | 
			
		||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 | 
			
		||||
template: qwen2_vl
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 | 
			
		||||
adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
template: qwen2_vl
 | 
			
		||||
finetuning_type: lora
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,5 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
 | 
			
		||||
### method
 | 
			
		||||
stage: sft
 | 
			
		||||
@ -9,7 +8,7 @@ finetuning_type: full
 | 
			
		||||
deepspeed: examples/deepspeed/ds_z3_config.json
 | 
			
		||||
 | 
			
		||||
### dataset
 | 
			
		||||
dataset: mllm_demo
 | 
			
		||||
dataset: mllm_demo,identity
 | 
			
		||||
template: qwen2_vl
 | 
			
		||||
cutoff_len: 1024
 | 
			
		||||
max_samples: 1000
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,5 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
 | 
			
		||||
### method
 | 
			
		||||
stage: sft
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,5 @@
 | 
			
		||||
### model
 | 
			
		||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 | 
			
		||||
visual_inputs: true
 | 
			
		||||
 | 
			
		||||
### method
 | 
			
		||||
stage: sft
 | 
			
		||||
@ -9,7 +8,7 @@ finetuning_type: lora
 | 
			
		||||
lora_target: all
 | 
			
		||||
 | 
			
		||||
### dataset
 | 
			
		||||
dataset: mllm_demo
 | 
			
		||||
dataset: mllm_demo,identity
 | 
			
		||||
template: qwen2_vl
 | 
			
		||||
cutoff_len: 1024
 | 
			
		||||
max_samples: 1000
 | 
			
		||||
 | 
			
		||||
@ -86,7 +86,7 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
            "max_lora_rank": model_args.vllm_max_lora_rank,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if model_args.visual_inputs:
 | 
			
		||||
        if getattr(config, "model_type", None) == "llava":
 | 
			
		||||
            image_size = config.vision_config.image_size
 | 
			
		||||
            patch_size = config.vision_config.patch_size
 | 
			
		||||
            self.image_feature_size = (image_size // patch_size) ** 2
 | 
			
		||||
 | 
			
		||||
@ -16,15 +16,12 @@
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
 | 
			
		||||
from typing import Any, Dict, Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from typing_extensions import Self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ModelArguments:
 | 
			
		||||
    r"""
 | 
			
		||||
@ -121,10 +118,6 @@ class ModelArguments:
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether or not to enable liger kernel for faster training."},
 | 
			
		||||
    )
 | 
			
		||||
    visual_inputs: bool = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
 | 
			
		||||
    )
 | 
			
		||||
    moe_aux_loss_coef: Optional[float] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
 | 
			
		||||
@ -225,19 +218,31 @@ class ModelArguments:
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
 | 
			
		||||
    )
 | 
			
		||||
    compute_dtype: Optional[torch.dtype] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        init=False,
 | 
			
		||||
        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
 | 
			
		||||
    )
 | 
			
		||||
    device_map: Optional[Union[str, Dict[str, Any]]] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        init=False,
 | 
			
		||||
        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
 | 
			
		||||
    )
 | 
			
		||||
    model_max_length: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        init=False,
 | 
			
		||||
        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
 | 
			
		||||
    )
 | 
			
		||||
    block_diag_attn: bool = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        init=False,
 | 
			
		||||
        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.compute_dtype: Optional["torch.dtype"] = None
 | 
			
		||||
        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
 | 
			
		||||
        self.model_max_length: Optional[int] = None
 | 
			
		||||
        self.block_diag_attn: bool = False
 | 
			
		||||
 | 
			
		||||
        if self.split_special_tokens and self.use_fast_tokenizer:
 | 
			
		||||
            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
 | 
			
		||||
 | 
			
		||||
        if self.visual_inputs and self.use_unsloth:
 | 
			
		||||
            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
 | 
			
		||||
 | 
			
		||||
        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
 | 
			
		||||
            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -257,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 | 
			
		||||
    if model_args.infer_backend == "vllm":
 | 
			
		||||
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
 | 
			
		||||
 | 
			
		||||
    if model_args.visual_inputs and data_args.packing:
 | 
			
		||||
        raise ValueError("Cannot use packing in MLLM fine-tuning.")
 | 
			
		||||
 | 
			
		||||
    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
 | 
			
		||||
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 | 
			
		||||
 | 
			
		||||
@ -388,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 | 
			
		||||
        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
 | 
			
		||||
            raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.stage == "rm" and model_args.visual_inputs:
 | 
			
		||||
        raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
 | 
			
		||||
 | 
			
		||||
    _verify_model_args(model_args, data_args, finetuning_args)
 | 
			
		||||
    _check_extra_dependencies(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@ from ..extras.logging import get_logger
 | 
			
		||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 | 
			
		||||
from .model_utils.quantization import QuantizationMethod
 | 
			
		||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
 | 
			
		||||
from .model_utils.visual import get_forbidden_modules, patch_target_modules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -37,7 +38,6 @@ logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
def _setup_full_tuning(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
    is_trainable: bool,
 | 
			
		||||
    cast_trainable_params_to_fp32: bool,
 | 
			
		||||
@ -46,13 +46,7 @@ def _setup_full_tuning(
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    logger.info("Fine-tuning method: Full")
 | 
			
		||||
    forbidden_modules = set()
 | 
			
		||||
    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
 | 
			
		||||
        forbidden_modules.add("vision_tower")
 | 
			
		||||
 | 
			
		||||
    if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
 | 
			
		||||
        forbidden_modules.add("language_model")
 | 
			
		||||
 | 
			
		||||
    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if not any(forbidden_module in name for forbidden_module in forbidden_modules):
 | 
			
		||||
            if cast_trainable_params_to_fp32:
 | 
			
		||||
@ -63,7 +57,6 @@ def _setup_full_tuning(
 | 
			
		||||
 | 
			
		||||
def _setup_freeze_tuning(
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
    is_trainable: bool,
 | 
			
		||||
    cast_trainable_params_to_fp32: bool,
 | 
			
		||||
@ -72,8 +65,8 @@ def _setup_freeze_tuning(
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    logger.info("Fine-tuning method: Freeze")
 | 
			
		||||
    if model_args.visual_inputs:
 | 
			
		||||
        config = model.config.text_config
 | 
			
		||||
    if hasattr(model.config, "text_config"):  # composite models
 | 
			
		||||
        config = getattr(model.config, "text_config")
 | 
			
		||||
    else:
 | 
			
		||||
        config = model.config
 | 
			
		||||
 | 
			
		||||
@ -130,10 +123,7 @@ def _setup_freeze_tuning(
 | 
			
		||||
 | 
			
		||||
            trainable_layers.append(module_name)
 | 
			
		||||
 | 
			
		||||
    forbidden_modules = set()
 | 
			
		||||
    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
 | 
			
		||||
        forbidden_modules.add("vision_tower")
 | 
			
		||||
 | 
			
		||||
    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
 | 
			
		||||
            forbidden_module in name for forbidden_module in forbidden_modules
 | 
			
		||||
@ -211,8 +201,7 @@ def _setup_lora_tuning(
 | 
			
		||||
        if finetuning_args.use_llama_pro:
 | 
			
		||||
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 | 
			
		||||
 | 
			
		||||
        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
 | 
			
		||||
            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
 | 
			
		||||
        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            finetuning_args.use_dora
 | 
			
		||||
@ -303,9 +292,9 @@ def init_adapter(
 | 
			
		||||
        cast_trainable_params_to_fp32 = True
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "full":
 | 
			
		||||
        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
 | 
			
		||||
        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
 | 
			
		||||
    elif finetuning_args.finetuning_type == "freeze":
 | 
			
		||||
        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
 | 
			
		||||
        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
 | 
			
		||||
    elif finetuning_args.finetuning_type == "lora":
 | 
			
		||||
        model = _setup_lora_tuning(
 | 
			
		||||
            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
 | 
			
		||||
 | 
			
		||||
@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
 | 
			
		||||
 | 
			
		||||
    patch_tokenizer(tokenizer)
 | 
			
		||||
 | 
			
		||||
    if model_args.visual_inputs:
 | 
			
		||||
        try:
 | 
			
		||||
            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
 | 
			
		||||
            setattr(processor, "tokenizer", tokenizer)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "This multimodal LLM is not supported.\n"
 | 
			
		||||
                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
 | 
			
		||||
                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
    try:
 | 
			
		||||
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
 | 
			
		||||
        setattr(processor, "tokenizer", tokenizer)
 | 
			
		||||
    except Exception:
 | 
			
		||||
        processor = None
 | 
			
		||||
 | 
			
		||||
    return {"tokenizer": tokenizer, "processor": processor}
 | 
			
		||||
@ -145,12 +138,16 @@ def load_model(
 | 
			
		||||
 | 
			
		||||
        if model_args.mixture_of_depths == "load":
 | 
			
		||||
            model = load_mod_pretrained_model(**init_kwargs)
 | 
			
		||||
        elif model_args.visual_inputs:
 | 
			
		||||
            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
 | 
			
		||||
        elif model_args.train_from_scratch:
 | 
			
		||||
            model = AutoModelForCausalLM.from_config(config)
 | 
			
		||||
        else:
 | 
			
		||||
            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
 | 
			
		||||
            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
 | 
			
		||||
                load_class = AutoModelForVision2Seq
 | 
			
		||||
            else:
 | 
			
		||||
                load_class = AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
            if model_args.train_from_scratch:
 | 
			
		||||
                model = load_class.from_config(config)
 | 
			
		||||
            else:
 | 
			
		||||
                model = load_class.from_pretrained(**init_kwargs)
 | 
			
		||||
 | 
			
		||||
        if model_args.mixture_of_depths == "convert":
 | 
			
		||||
            model = convert_pretrained_model_to_mod(model, config, model_args)
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, Tuple
 | 
			
		||||
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import transformers.models
 | 
			
		||||
@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
 | 
			
		||||
 | 
			
		||||
    from ...hparams import ModelArguments
 | 
			
		||||
    from ...hparams import FinetuningArguments, ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 | 
			
		||||
        self.act = ACT2FN[projector_hidden_act]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def autocast_projector_dtype(
 | 
			
		||||
    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 | 
			
		||||
) -> None:
 | 
			
		||||
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Casts projector output to half precision for quantized VLMs.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def _mm_projector_forward_post_hook(
 | 
			
		||||
        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
 | 
			
		||||
    ) -> "torch.Tensor":
 | 
			
		||||
        return output.to(model_args.compute_dtype)
 | 
			
		||||
 | 
			
		||||
    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
 | 
			
		||||
    if getattr(model, "quantization_method", None):
 | 
			
		||||
        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
 | 
			
		||||
            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
 | 
			
		||||
        elif getattr(model.config, "model_type", None) == "qwen2_vl":
 | 
			
		||||
            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
 | 
			
		||||
        else:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
 | 
			
		||||
        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
 | 
			
		||||
        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def configure_visual_model(config: "PretrainedConfig") -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Patches VLMs before loading them.
 | 
			
		||||
    """
 | 
			
		||||
    if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
 | 
			
		||||
        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
 | 
			
		||||
 | 
			
		||||
    if getattr(config, "is_yi_vl_derived_model", None):
 | 
			
		||||
        logger.info("Detected Yi-VL model, applying projector patch.")
 | 
			
		||||
        transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Freezes vision tower and language model for VLM full/freeze tuning.
 | 
			
		||||
    """
 | 
			
		||||
    forbidden_modules = set()
 | 
			
		||||
    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
 | 
			
		||||
        if finetuning_args.freeze_vision_tower:
 | 
			
		||||
            forbidden_modules.add("vision_tower")
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.train_mm_proj_only:
 | 
			
		||||
            forbidden_modules.add("language_model")
 | 
			
		||||
 | 
			
		||||
    elif getattr(config, "model_type", None) == "qwen2_vl":
 | 
			
		||||
        if finetuning_args.freeze_vision_tower:
 | 
			
		||||
            forbidden_modules.add("visual")
 | 
			
		||||
 | 
			
		||||
        if finetuning_args.train_mm_proj_only:
 | 
			
		||||
            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
 | 
			
		||||
 | 
			
		||||
    return forbidden_modules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_target_modules(
 | 
			
		||||
    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
 | 
			
		||||
) -> Union[str, List[str]]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Freezes vision tower for VLM LoRA tuning.
 | 
			
		||||
    """
 | 
			
		||||
    if not finetuning_args.freeze_vision_tower:
 | 
			
		||||
        return target_modules
 | 
			
		||||
 | 
			
		||||
    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
 | 
			
		||||
        return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
 | 
			
		||||
    elif getattr(config, "model_type", None) == "qwen2_vl":
 | 
			
		||||
        return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
 | 
			
		||||
    else:
 | 
			
		||||
        return target_modules
 | 
			
		||||
 | 
			
		||||
@ -131,11 +131,9 @@ def patch_model(
 | 
			
		||||
    if model_args.resize_vocab:
 | 
			
		||||
        resize_embedding_layer(model, tokenizer)
 | 
			
		||||
 | 
			
		||||
    if model_args.visual_inputs:
 | 
			
		||||
        autocast_projector_dtype(model, model_args)
 | 
			
		||||
 | 
			
		||||
    if is_trainable:
 | 
			
		||||
        prepare_model_for_training(model, model_args)
 | 
			
		||||
        autocast_projector_dtype(model, model_args)
 | 
			
		||||
        add_z3_leaf_module(model)
 | 
			
		||||
 | 
			
		||||
    if not model_args.use_unsloth:
 | 
			
		||||
 | 
			
		||||
@ -61,7 +61,6 @@ def run_sft(
 | 
			
		||||
    # Override the decoding parameters of Seq2SeqTrainer
 | 
			
		||||
    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
 | 
			
		||||
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 | 
			
		||||
    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
 | 
			
		||||
 | 
			
		||||
    # Metric utils
 | 
			
		||||
    metric_module = {}
 | 
			
		||||
 | 
			
		||||
@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
 | 
			
		||||
        if model_args.export_hub_model_id is not None:
 | 
			
		||||
            tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
 | 
			
		||||
 | 
			
		||||
        if model_args.visual_inputs and processor is not None:
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
 | 
			
		||||
            if model_args.export_hub_model_id is not None:
 | 
			
		||||
                getattr(processor, "image_processor").push_to_hub(
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,6 @@ class WebChatModel(ChatModel):
 | 
			
		||||
            template=get("top.template"),
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
            use_unsloth=(get("top.booster") == "unsloth"),
 | 
			
		||||
            visual_inputs=get("top.visual_inputs"),
 | 
			
		||||
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
 | 
			
		||||
            infer_backend=get("infer.infer_backend"),
 | 
			
		||||
            infer_dtype=get("infer.infer_dtype"),
 | 
			
		||||
 | 
			
		||||
@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str:
 | 
			
		||||
    return model_name.split("-")[0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_model_info(model_name: str) -> Tuple[str, str, bool]:
 | 
			
		||||
def get_model_info(model_name: str) -> Tuple[str, str]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Gets the necessary information of this model.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        model_path (str)
 | 
			
		||||
        template (str)
 | 
			
		||||
        visual (bool)
 | 
			
		||||
    """
 | 
			
		||||
    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
 | 
			
		||||
    return get_model_path(model_name), get_template(model_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_template(model_name: str) -> str:
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Dict
 | 
			
		||||
 | 
			
		||||
from ...extras.packages import is_gradio_available
 | 
			
		||||
from ..common import get_visual
 | 
			
		||||
from .chatbot import create_chat_box
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
 | 
			
		||||
        lambda: ([], []), outputs=[chatbot, messages]
 | 
			
		||||
    ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
 | 
			
		||||
 | 
			
		||||
    engine.manager.get_elem_by_id("top.visual_inputs").change(
 | 
			
		||||
        lambda enabled: gr.Column(visible=enabled),
 | 
			
		||||
        [engine.manager.get_elem_by_id("top.visual_inputs")],
 | 
			
		||||
    engine.manager.get_elem_by_id("top.model_name").change(
 | 
			
		||||
        lambda model_name: gr.Column(visible=get_visual(model_name)),
 | 
			
		||||
        [engine.manager.get_elem_by_id("top.model_name")],
 | 
			
		||||
        [chat_elems["image_box"]],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
 | 
			
		||||
            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
 | 
			
		||||
            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
 | 
			
		||||
            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
 | 
			
		||||
            visual_inputs = gr.Checkbox(scale=1)
 | 
			
		||||
 | 
			
		||||
    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
 | 
			
		||||
    model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
 | 
			
		||||
        list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
 | 
			
		||||
    )
 | 
			
		||||
    model_name.input(save_config, inputs=[lang, model_name], queue=False)
 | 
			
		||||
@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
 | 
			
		||||
        template=template,
 | 
			
		||||
        rope_scaling=rope_scaling,
 | 
			
		||||
        booster=booster,
 | 
			
		||||
        visual_inputs=visual_inputs,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -183,20 +183,6 @@ LOCALES = {
 | 
			
		||||
            "label": "부스터",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "visual_inputs": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Visual inputs",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "визуальные входы",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "图像输入",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "시각적 입력",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "training_stage": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Stage",
 | 
			
		||||
 | 
			
		||||
@ -75,5 +75,4 @@ class Manager:
 | 
			
		||||
            self._id_to_elem["top.template"],
 | 
			
		||||
            self._id_to_elem["top.rope_scaling"],
 | 
			
		||||
            self._id_to_elem["top.booster"],
 | 
			
		||||
            self._id_to_elem["top.visual_inputs"],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -116,7 +116,6 @@ class Runner:
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
            use_unsloth=(get("top.booster") == "unsloth"),
 | 
			
		||||
            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
 | 
			
		||||
            visual_inputs=get("top.visual_inputs"),
 | 
			
		||||
            dataset_dir=get("train.dataset_dir"),
 | 
			
		||||
            dataset=",".join(get("train.dataset")),
 | 
			
		||||
            cutoff_len=get("train.cutoff_len"),
 | 
			
		||||
@ -252,7 +251,6 @@ class Runner:
 | 
			
		||||
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
            use_unsloth=(get("top.booster") == "unsloth"),
 | 
			
		||||
            visual_inputs=get("top.visual_inputs"),
 | 
			
		||||
            dataset_dir=get("eval.dataset_dir"),
 | 
			
		||||
            eval_dataset=",".join(get("eval.dataset")),
 | 
			
		||||
            cutoff_len=get("eval.cutoff_len"),
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user