diff --git a/examples/inference/llava1_5.yaml b/examples/inference/llava1_5.yaml
index 4a7673e1..68f3b8ff 100644
--- a/examples/inference/llava1_5.yaml
+++ b/examples/inference/llava1_5.yaml
@@ -1,3 +1,2 @@
 model_name_or_path: llava-hf/llava-1.5-7b-hf
 template: llava
-visual_inputs: true
diff --git a/examples/inference/qwen2_vl.yaml b/examples/inference/qwen2_vl.yaml
index a875f0d2..ed1cef6c 100644
--- a/examples/inference/qwen2_vl.yaml
+++ b/examples/inference/qwen2_vl.yaml
@@ -1,3 +1,2 @@
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 template: qwen2_vl
-visual_inputs: true
diff --git a/examples/merge_lora/qwen2vl_lora_sft.yaml b/examples/merge_lora/qwen2vl_lora_sft.yaml
index c71cd87e..f97e2766 100644
--- a/examples/merge_lora/qwen2vl_lora_sft.yaml
+++ b/examples/merge_lora/qwen2vl_lora_sft.yaml
@@ -3,7 +3,6 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
-visual_inputs: true
 template: qwen2_vl
 finetuning_type: lora
 
diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index 1163a37e..a6cd40fb 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true
 
 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: full
 deepspeed: examples/deepspeed/ds_z3_config.json
 
 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
diff --git a/examples/train_lora/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
index f0616ac8..00a2ebc0 100644
--- a/examples/train_lora/llava1_5_lora_sft.yaml
+++ b/examples/train_lora/llava1_5_lora_sft.yaml
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: llava-hf/llava-1.5-7b-hf
-visual_inputs: true
 
 ### method
 stage: sft
diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml
index 74b58922..597db248 100644
--- a/examples/train_lora/qwen2vl_lora_sft.yaml
+++ b/examples/train_lora/qwen2vl_lora_sft.yaml
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true
 
 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index d64f4d25..26f2896e 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -86,7 +86,7 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
 
-        if model_args.visual_inputs:
+        if getattr(config, "model_type", None) == "llava":
            image_size = config.vision_config.image_size
            patch_size = config.vision_config.patch_size
            self.image_feature_size = (image_size // patch_size) ** 2
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index b25e31e0..65f0fa62 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -16,15 +16,12 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, Literal, Optional, Union
 
+import torch
 from typing_extensions import Self
 
 
-if TYPE_CHECKING:
-    import torch
-
-
 @dataclass
 class ModelArguments:
     r"""
@@ -121,10 +118,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    visual_inputs: bool = field(
-        default=False,
-        metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
-    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@@ -225,19 +218,31 @@ class ModelArguments:
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
+    )
+    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        init=False,
+        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
+    )
+    block_diag_attn: bool = field(
+        default=False,
+        init=False,
+        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
+    )
 
     def __post_init__(self):
-        self.compute_dtype: Optional["torch.dtype"] = None
-        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
-        self.model_max_length: Optional[int] = None
-        self.block_diag_attn: bool = False
-
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
 
-        if self.visual_inputs and self.use_unsloth:
-            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 7e3e39cd..a7dfb0bd 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -257,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
-    if model_args.visual_inputs and data_args.packing:
-        raise ValueError("Cannot use packing in MLLM fine-tuning.")
-
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
@@ -388,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
-    if finetuning_args.stage == "rm" and model_args.visual_inputs:
-        raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index f18bcbc9..25b0eab8 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .model_utils.visual import get_forbidden_modules, patch_target_modules
 
 
 if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)
 
 def _setup_full_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@
         return
 
     logger.info("Fine-tuning method: Full")
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
-    if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
-        forbidden_modules.add("language_model")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
             if cast_trainable_params_to_fp32:
@@ -63,7 +57,6 @@
 
 def _setup_freeze_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@
         return
 
     logger.info("Fine-tuning method: Freeze")
-    if model_args.visual_inputs:
-        config = model.config.text_config
+    if hasattr(model.config, "text_config"):  # composite models
+        config = getattr(model.config, "text_config")
     else:
         config = model.config
 
@@ -130,10 +123,7 @@
             trainable_layers.append(module_name)
 
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
             forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
-        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
+        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
 
         if (
             finetuning_args.use_dora
@@ -303,9 +292,9 @@
         cast_trainable_params_to_fp32 = True
 
     if finetuning_args.finetuning_type == "full":
-        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
-        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "lora":
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index fe700d53..9a16c0ce 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
 
     patch_tokenizer(tokenizer)
 
-    if model_args.visual_inputs:
-        try:
-            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-            setattr(processor, "tokenizer", tokenizer)
-        except Exception:
-            raise ValueError(
-                "This multimodal LLM is not supported.\n"
-                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
-                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
-            )
-    else:
+    try:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    except Exception:
         processor = None
 
     return {"tokenizer": tokenizer, "processor": processor}
@@ -145,12 +138,16 @@ def load_model(
 
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
-        elif model_args.visual_inputs:
-            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-        elif model_args.train_from_scratch:
-            model = AutoModelForCausalLM.from_config(config)
         else:
-            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+                load_class = AutoModelForVision2Seq
+            else:
+                load_class = AutoModelForCausalLM
+
+            if model_args.train_from_scratch:
+                model = load_class.from_config(config)
+            else:
+                model = load_class.from_pretrained(**init_kwargs)
 
         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 828a5e6d..b3103a2c 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
 
 import torch
 import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
 
-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
         self.act = ACT2FN[projector_hidden_act]
 
 
-def autocast_projector_dtype(
-    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
-) -> None:
+def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+    r"""
+    Casts projector output to half precision for quantized VLMs.
+ """ + def _mm_projector_forward_post_hook( module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" ) -> "torch.Tensor": return output.to(model_args.compute_dtype) - if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None): + if getattr(model, "quantization_method", None): + if getattr(model.config, "model_type", None) in ["llava", "paligemma"]: + mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") + elif getattr(model.config, "model_type", None) == "qwen2_vl": + mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") + else: + return + logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) - mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector.register_forward_hook(_mm_projector_forward_post_hook) def configure_visual_model(config: "PretrainedConfig") -> None: + r""" + Patches VLMs before loading them. + """ if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) if getattr(config, "is_yi_vl_derived_model", None): logger.info("Detected Yi-VL model, applying projector patch.") transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL + + +def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]: + r""" + Freezes vision tower and language model for VLM full/freeze tuning. + """ + forbidden_modules = set() + if getattr(config, "model_type", None) in ["llava", "paligemma"]: + if finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + if finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + + elif getattr(config, "model_type", None) == "qwen2_vl": + if finetuning_args.freeze_vision_tower: + forbidden_modules.add("visual") + + if finetuning_args.train_mm_proj_only: + raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.") + + return forbidden_modules + + +def patch_target_modules( + config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] +) -> Union[str, List[str]]: + r""" + Freezes vision tower for VLM LoRA tuning. 
+ """ + if not finetuning_args.freeze_vision_tower: + return target_modules + + if getattr(config, "model_type", None) in ["llava", "paligemma"]: + return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + elif getattr(config, "model_type", None) == "qwen2_vl": + return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) + else: + return target_modules diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index a278c154..3de82703 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -131,11 +131,9 @@ def patch_model( if model_args.resize_vocab: resize_embedding_layer(model, tokenizer) - if model_args.visual_inputs: - autocast_projector_dtype(model, model_args) - if is_trainable: prepare_model_for_training(model, model_args) + autocast_projector_dtype(model, model_args) add_z3_leaf_module(model) if not model_args.use_unsloth: diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index 5da99557..5e3787f1 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -61,7 +61,6 @@ def run_sft( # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams - training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns # Metric utils metric_module = {} diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index cb55900f..e0d7a7c9 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: if model_args.export_hub_model_id is not None: tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) - if model_args.visual_inputs and processor is not None: + if processor is not None: getattr(processor, "image_processor").save_pretrained(model_args.export_dir) if model_args.export_hub_model_id is not None: getattr(processor, "image_processor").push_to_hub( diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 8abef920..0b8267ed 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -90,7 +90,6 @@ class WebChatModel(ChatModel): template=get("top.template"), flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), - visual_inputs=get("top.visual_inputs"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, infer_backend=get("infer.infer_backend"), infer_dtype=get("infer.infer_dtype"), diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index e83cadd9..019812c7 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str: return model_name.split("-")[0] -def get_model_info(model_name: str) -> Tuple[str, str, bool]: +def get_model_info(model_name: str) -> Tuple[str, str]: r""" Gets the necessary information of this model. 
 
     Returns:
         model_path (str)
         template (str)
-        visual (bool)
     """
-    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+    return get_model_path(model_name), get_template(model_name)
 
 
 def get_template(model_name: str) -> str:
diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py
index a0064479..9bbdc842 100644
--- a/src/llamafactory/webui/components/infer.py
+++ b/src/llamafactory/webui/components/infer.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Dict
 
 from ...extras.packages import is_gradio_available
+from ..common import get_visual
 from .chatbot import create_chat_box
 
 
@@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
         lambda: ([], []), outputs=[chatbot, messages]
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
 
-    engine.manager.get_elem_by_id("top.visual_inputs").change(
-        lambda enabled: gr.Column(visible=enabled),
-        [engine.manager.get_elem_by_id("top.visual_inputs")],
+    engine.manager.get_elem_by_id("top.model_name").change(
+        lambda model_name: gr.Column(visible=get_visual(model_name)),
+        [engine.manager.get_elem_by_id("top.model_name")],
         [chat_elems["image_box"]],
     )
 
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index f2630c7b..55096601 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
         template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
         booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
-        visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
@@ -73,5 +72,4 @@
         template=template,
         rope_scaling=rope_scaling,
         booster=booster,
-        visual_inputs=visual_inputs,
     )
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 0a8ca68a..f742c246 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -183,20 +183,6 @@ LOCALES = {
             "label": "부스터",
         },
     },
-    "visual_inputs": {
-        "en": {
-            "label": "Visual inputs",
-        },
-        "ru": {
-            "label": "визуальные входы",
-        },
-        "zh": {
-            "label": "图像输入",
-        },
-        "ko": {
-            "label": "시각적 입력",
-        },
-    },
     "training_stage": {
         "en": {
             "label": "Stage",
diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py
index ebe9f1b9..61c2f35a 100644
--- a/src/llamafactory/webui/manager.py
+++ b/src/llamafactory/webui/manager.py
@@ -75,5 +75,4 @@ class Manager:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
-            self._id_to_elem["top.visual_inputs"],
         }
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 67d910fa..55d15b41 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -116,7 +116,6 @@ class Runner:
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
enable_liger_kernel=(get("top.booster") == "liger_kernel"), - visual_inputs=get("top.visual_inputs"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")), cutoff_len=get("train.cutoff_len"), @@ -252,7 +251,6 @@ class Runner: rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), - visual_inputs=get("top.visual_inputs"), dataset_dir=get("eval.dataset_dir"), eval_dataset=",".join(get("eval.dataset")), cutoff_len=get("eval.cutoff_len"),