mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 03:02:51 +08:00

remove visual_inputs, fix qlora

Former-commit-id: a025c3df61db154bef13033518903bbf846f4fc8
This commit is contained in:
parent 51a0016873
commit f31e7e0dfc
@@ -1,3 +1,2 @@
 model_name_or_path: llava-hf/llava-1.5-7b-hf
 template: llava
-visual_inputs: true
@@ -1,3 +1,2 @@
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 template: qwen2_vl
-visual_inputs: true
@@ -3,7 +3,6 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 adapter_name_or_path: saves/qwen2_vl-7b/lora/sft
-visual_inputs: true
 template: qwen2_vl
 finetuning_type: lora
 
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true
 
 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: full
 deepspeed: examples/deepspeed/ds_z3_config.json
 
 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: llava-hf/llava-1.5-7b-hf
-visual_inputs: true
 
 ### method
 stage: sft
@@ -1,6 +1,5 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-visual_inputs: true
 
 ### method
 stage: sft
@@ -9,7 +8,7 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset: mllm_demo
+dataset: mllm_demo,identity
 template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
@@ -86,7 +86,7 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
 
-        if model_args.visual_inputs:
+        if getattr(config, "model_type", None) == "llava":
            image_size = config.vision_config.image_size
            patch_size = config.vision_config.patch_size
            self.image_feature_size = (image_size // patch_size) ** 2
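For context, a minimal sketch of the image_feature_size arithmetic kept above, assuming the stock llava-hf/llava-1.5-7b-hf vision config values (image_size=336, patch_size=14); in the engine these numbers come from the checkpoint's config rather than constants.

    # Hedged sketch: reproduces the image_feature_size arithmetic from the hunk above,
    # assuming the typical LLaVA-1.5 vision config (336px images, 14px patches).
    image_size = 336   # config.vision_config.image_size for llava-hf/llava-1.5-7b-hf
    patch_size = 14    # config.vision_config.patch_size
    image_feature_size = (image_size // patch_size) ** 2
    print(image_feature_size)  # 576 visual tokens per image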
@@ -16,15 +16,12 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, Literal, Optional, Union
 
+import torch
 from typing_extensions import Self
 
 
-if TYPE_CHECKING:
-    import torch
-
-
 @dataclass
 class ModelArguments:
     r"""
@@ -121,10 +118,6 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    visual_inputs: bool = field(
-        default=False,
-        metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
-    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
@@ -225,19 +218,31 @@ class ModelArguments:
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
+    )
+    device_map: Optional[Union[str, Dict[str, Any]]] = field(
+        default=None,
+        init=False,
+        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        init=False,
+        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
+    )
+    block_diag_attn: bool = field(
+        default=False,
+        init=False,
+        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
+    )
 
     def __post_init__(self):
-        self.compute_dtype: Optional["torch.dtype"] = None
-        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
-        self.model_max_length: Optional[int] = None
-        self.block_diag_attn: bool = False
-
         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
 
-        if self.visual_inputs and self.use_unsloth:
-            raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
-
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 
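A minimal, self-contained sketch of the `field(init=False)` pattern introduced above: derived attributes such as `compute_dtype` stop being constructor parameters but still exist on every instance with their defaults. The `Args` class below is a hypothetical stand-in, not the real ModelArguments.

    # Hedged sketch: dataclass fields with init=False are excluded from __init__
    # yet still created with their defaults, mirroring the derived fields above.
    from dataclasses import dataclass, field, fields
    from typing import Optional


    @dataclass
    class Args:  # hypothetical stand-in for ModelArguments
        model_name_or_path: Optional[str] = None
        compute_dtype: Optional[str] = field(
            default=None,
            init=False,
            metadata={"help": "Derived at parse time. Do not specify it."},
        )


    args = Args(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    print(args.compute_dtype)                          # None; filled in later by the parser
    print([f.name for f in fields(Args) if f.init])    # only ['model_name_or_path'] is accepted by __init__
    # Args(compute_dtype="bf16") would raise TypeError: unexpected keyword argument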
@@ -257,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
-    if model_args.visual_inputs and data_args.packing:
-        raise ValueError("Cannot use packing in MLLM fine-tuning.")
-
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
@@ -388,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
-    if finetuning_args.stage == "rm" and model_args.visual_inputs:
-        raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .model_utils.visual import get_forbidden_modules, patch_target_modules
 
 
 if TYPE_CHECKING:
@@ -37,7 +38,6 @@ logger = get_logger(__name__)
 
 def _setup_full_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -46,13 +46,7 @@ def _setup_full_tuning(
         return
 
     logger.info("Fine-tuning method: Full")
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
-    if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
-        forbidden_modules.add("language_model")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
             if cast_trainable_params_to_fp32:
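A small sketch of how a `forbidden_modules` name set freezes parameters; the loop kept as context above iterates `model.named_parameters()` the same way. The `ToyVLM` module below is purely illustrative.

    # Hedged sketch: freeze parameters whose names contain a forbidden module name,
    # mirroring the full-tuning loop above (toy model, not a real VLM).
    import torch
    from torch import nn


    class ToyVLM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.vision_tower = nn.Linear(4, 4)    # frozen when freeze_vision_tower is set
            self.language_model = nn.Linear(4, 4)  # stays trainable in full tuning


    model = ToyVLM()
    forbidden_modules = {"vision_tower"}  # e.g. what get_forbidden_modules(...) returns
    for name, param in model.named_parameters():
        if not any(forbidden_module in name for forbidden_module in forbidden_modules):
            param.data = param.data.to(torch.float32)
        else:
            param.requires_grad_(False)

    print([name for name, p in model.named_parameters() if p.requires_grad])
    # ['language_model.weight', 'language_model.bias']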
@@ -63,7 +57,6 @@ def _setup_full_tuning(
 
 def _setup_freeze_tuning(
     model: "PreTrainedModel",
-    model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
         return
 
     logger.info("Fine-tuning method: Freeze")
-    if model_args.visual_inputs:
-        config = model.config.text_config
+    if hasattr(model.config, "text_config"):  # composite models
+        config = getattr(model.config, "text_config")
     else:
         config = model.config
 
@@ -130,10 +123,7 @@ def _setup_freeze_tuning(
 
         trainable_layers.append(module_name)
 
-    forbidden_modules = set()
-    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
-
+    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
             forbidden_module in name for forbidden_module in forbidden_modules
@@ -211,8 +201,7 @@ def _setup_lora_tuning(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
-        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
+        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
 
         if (
             finetuning_args.use_dora
@@ -303,9 +292,9 @@ def init_adapter(
         cast_trainable_params_to_fp32 = True
 
     if finetuning_args.finetuning_type == "full":
-        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "freeze":
-        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
     elif finetuning_args.finetuning_type == "lora":
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
@@ -93,17 +93,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
 
     patch_tokenizer(tokenizer)
 
-    if model_args.visual_inputs:
-        try:
-            processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-            setattr(processor, "tokenizer", tokenizer)
-        except Exception:
-            raise ValueError(
-                "This multimodal LLM is not supported.\n"
-                "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
-                "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
-            )
-    else:
+    try:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    except Exception:
         processor = None
 
     return {"tokenizer": tokenizer, "processor": processor}
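The new branch above simply probes for a processor and falls back to `None`. A hedged standalone sketch of that pattern is below; the model id is an example and downloading it requires network access plus a transformers version that ships the corresponding processor.

    # Hedged sketch of the probe-and-fallback pattern above: multimodal checkpoints
    # yield a processor, and any failure leaves processor as None for text-only use.
    from transformers import AutoProcessor, AutoTokenizer

    model_id = "llava-hf/llava-1.5-7b-hf"  # example VLM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    try:
        processor = AutoProcessor.from_pretrained(model_id)
        setattr(processor, "tokenizer", tokenizer)
    except Exception:
        processor = None

    print(type(processor).__name__)  # e.g. LlavaProcessor when the checkpoint ships one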
@@ -145,12 +138,16 @@ def load_model(
 
         if model_args.mixture_of_depths == "load":
             model = load_mod_pretrained_model(**init_kwargs)
-        elif model_args.visual_inputs:
-            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-        elif model_args.train_from_scratch:
-            model = AutoModelForCausalLM.from_config(config)
         else:
-            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+                load_class = AutoModelForVision2Seq
+            else:
+                load_class = AutoModelForCausalLM
+
+            if model_args.train_from_scratch:
+                model = load_class.from_config(config)
+            else:
+                model = load_class.from_pretrained(**init_kwargs)
 
         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
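Since `visual_inputs` is gone, the model class is now picked from the config type. A hedged sketch of that check in isolation: it fetches configs over the network, relies on the private `_model_mapping` attribute, and needs a transformers release recent enough to register Qwen2-VL under AutoModelForVision2Seq; the model ids are examples.

    # Hedged sketch: detect whether a checkpoint is a built-in vision2seq model by
    # checking its config class against AutoModelForVision2Seq's model mapping, as above.
    from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

    for model_id in ("Qwen/Qwen2-VL-7B-Instruct", "Qwen/Qwen2-7B-Instruct"):
        config = AutoConfig.from_pretrained(model_id)
        if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
            load_class = AutoModelForVision2Seq
        else:
            load_class = AutoModelForCausalLM
        print(model_id, "->", load_class.__name__)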
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
 
 import torch
 import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
 
-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -80,24 +80,74 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
         self.act = ACT2FN[projector_hidden_act]
 
 
-def autocast_projector_dtype(
-    model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
-) -> None:
+def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
+    r"""
+    Casts projector output to half precision for quantized VLMs.
+    """
+
     def _mm_projector_forward_post_hook(
         module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
+    if getattr(model, "quantization_method", None):
+        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
+            mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
+        elif getattr(model.config, "model_type", None) == "qwen2_vl":
+            mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+        else:
+            return
+
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
-        mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
 
 
 def configure_visual_model(config: "PretrainedConfig") -> None:
+    r"""
+    Patches VLMs before loading them.
+    """
     if getattr(config, "model_type", None) == "llava":  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
 
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
+    r"""
+    Freezes vision tower and language model for VLM full/freeze tuning.
+    """
+    forbidden_modules = set()
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_tower")
+
+        if finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("visual")
+
+        if finetuning_args.train_mm_proj_only:
+            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+
+    return forbidden_modules
+
+
+def patch_target_modules(
+    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> Union[str, List[str]]:
+    r"""
+    Freezes vision tower for VLM LoRA tuning.
+    """
+    if not finetuning_args.freeze_vision_tower:
+        return target_modules
+
+    if getattr(config, "model_type", None) in ["llava", "paligemma"]:
+        return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+    elif getattr(config, "model_type", None) == "qwen2_vl":
+        return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+    else:
+        return target_modules
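For illustration, a hedged sketch of what `patch_target_modules` produces for a LLaVA-style model and how the resulting regex keeps the vision tower out of the LoRA targets. The module names below are abbreviated examples; PEFT treats a string `target_modules` value as a regular expression matched against full parameter-module names.

    # Hedged sketch: the regex built by patch_target_modules above excludes any module whose
    # full name mentions the vision tower while still matching the requested projection layers.
    import re

    target_modules = ["q_proj", "v_proj"]
    pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
    print(pattern)  # ^(?!.*vision_tower).*(?:q_proj|v_proj).*

    candidates = [
        "language_model.model.layers.0.self_attn.q_proj",               # language module: matched
        "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj",  # vision module: excluded
    ]
    for name in candidates:
        print(name, bool(re.fullmatch(pattern, name)))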
@@ -131,11 +131,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)
 
-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
-
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
 
     if not model_args.use_unsloth:
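A self-contained sketch of the forward-hook cast that `autocast_projector_dtype` registers; in the real code the hook is only attached when the model is quantized, and the projector below is a toy stand-in.

    # Hedged sketch: a forward hook that returns a value replaces the module's output,
    # which is how the projector output is cast back to the training compute dtype above.
    import torch
    from torch import nn

    compute_dtype = torch.bfloat16  # stands in for model_args.compute_dtype


    def _mm_projector_forward_post_hook(module, args, output):
        return output.to(compute_dtype)


    mm_projector = nn.Linear(8, 8)  # toy stand-in for the multimodal projector
    mm_projector.register_forward_hook(_mm_projector_forward_post_hook)

    out = mm_projector(torch.randn(2, 8))
    print(out.dtype)  # torch.bfloat16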
@@ -61,7 +61,6 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
 
     # Metric utils
     metric_module = {}
@@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     if model_args.export_hub_model_id is not None:
         tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
 
-    if model_args.visual_inputs and processor is not None:
+    if processor is not None:
         getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
         if model_args.export_hub_model_id is not None:
             getattr(processor, "image_processor").push_to_hub(
@@ -90,7 +90,6 @@ class WebChatModel(ChatModel):
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
@@ -122,16 +122,15 @@ def get_prefix(model_name: str) -> str:
     return model_name.split("-")[0]
 
 
-def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+def get_model_info(model_name: str) -> Tuple[str, str]:
     r"""
     Gets the necessary information of this model.
 
     Returns:
         model_path (str)
         template (str)
-        visual (bool)
     """
-    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+    return get_model_path(model_name), get_template(model_name)
 
 
 def get_template(model_name: str) -> str:
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Dict
 
 from ...extras.packages import is_gradio_available
+from ..common import get_visual
 from .chatbot import create_chat_box
 
 
@@ -64,9 +65,9 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
         lambda: ([], []), outputs=[chatbot, messages]
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
 
-    engine.manager.get_elem_by_id("top.visual_inputs").change(
-        lambda enabled: gr.Column(visible=enabled),
-        [engine.manager.get_elem_by_id("top.visual_inputs")],
+    engine.manager.get_elem_by_id("top.model_name").change(
+        lambda model_name: gr.Column(visible=get_visual(model_name)),
+        [engine.manager.get_elem_by_id("top.model_name")],
         [chat_elems["image_box"]],
     )
 
@@ -48,9 +48,8 @@ def create_top() -> Dict[str, "Component"]:
         template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
         booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
-        visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
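The Web UI change above relies on a Gradio event handler whose returned tuple is mapped positionally onto its outputs list; a hedged toy sketch of that wiring is below, with illustrative labels, choices, and mapping rather than the real Web UI constants.

    # Hedged sketch: a .change() callback returning a tuple fills its outputs in order,
    # which is how get_model_info(model_name) now drives only [model_path, template].
    import gradio as gr


    def get_model_info(model_name: str):  # toy stand-in for the Web UI helper
        known = {"LLaVA-1.5-7B-Chat": ("llava-hf/llava-1.5-7b-hf", "llava")}
        return known.get(model_name, ("", "default"))


    with gr.Blocks() as demo:
        model_name = gr.Dropdown(choices=["LLaVA-1.5-7B-Chat"], label="Model name")
        model_path = gr.Textbox(label="Model path")
        template = gr.Dropdown(choices=["default", "llava"], label="Template")
        model_name.change(get_model_info, [model_name], [model_path, template], queue=False)

    # demo.launch()  # uncomment to try it locally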
@@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
         template=template,
         rope_scaling=rope_scaling,
         booster=booster,
-        visual_inputs=visual_inputs,
     )
@@ -183,20 +183,6 @@ LOCALES = {
             "label": "부스터",
         },
     },
-    "visual_inputs": {
-        "en": {
-            "label": "Visual inputs",
-        },
-        "ru": {
-            "label": "визуальные входы",
-        },
-        "zh": {
-            "label": "图像输入",
-        },
-        "ko": {
-            "label": "시각적 입력",
-        },
-    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -75,5 +75,4 @@ class Manager:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
-            self._id_to_elem["top.visual_inputs"],
         }
@@ -116,7 +116,6 @@ class Runner:
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
-            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -252,7 +251,6 @@ class Runner:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("eval.dataset_dir"),
             eval_dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),