mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
@@ -117,6 +117,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
use_liger_kernel: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
visual_inputs: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
|
||||
|
||||
@@ -116,6 +116,9 @@ def _check_extra_dependencies(
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
|
||||
if model_args.use_liger_kernel:
|
||||
require_version("liger-kernel", "To fix: pip install liger-kernel")
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
|
||||
48
src/llamafactory/model/model_utils/liger_kernel.py
Normal file
48
src/llamafactory/model/model_utils/liger_kernel.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.use_liger_kernel:
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) == "gemma":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "llama":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "mistral":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "mixtral":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
|
||||
elif getattr(config, "model_type", None) == "qwen2":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
|
||||
else:
|
||||
logger.warning("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
apply_liger_kernel()
|
||||
logger.info("Liger kernel has been applied to the model.")
|
||||
@@ -27,6 +27,7 @@ from ..extras.misc import infer_optim_dtype
|
||||
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
|
||||
from .model_utils.checkpointing import prepare_model_for_training
|
||||
from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.liger_kernel import configure_liger_kernel
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
@@ -70,6 +71,7 @@ def patch_config(
|
||||
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_liger_kernel(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
|
||||
@@ -47,7 +47,7 @@ def create_top() -> Dict[str, "Component"]:
|
||||
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
|
||||
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
|
||||
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
|
||||
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3)
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
|
||||
|
||||
@@ -115,6 +115,7 @@ class Runner:
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
use_liger_kernel=(get("top.booster") == "liger_kernel"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
|
||||
Reference in New Issue
Block a user