diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 9e8c019c..cd2f1867 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union
 
 import torch
@@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")
 
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)
 
-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 9e47fb72..b4ef16e8 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -21,6 +21,7 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
 from .adapter import init_adapter
+from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
@@ -128,6 +129,7 @@ def load_model(
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
 
     model = None
     lazy_load = False
diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index 9f9cd20d..e554ccbc 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 from typing import TYPE_CHECKING
 
 from ...extras.logging import get_logger
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def apply_liger_kernel(
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    require_logits: bool,
+) -> None:
     if not is_trainable or not model_args.enable_liger_kernel:
         return
 
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
         logger.warning("Current model does not support liger kernel.")
         return
 
-    apply_liger_kernel()
+    if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
+        logger.info("Current training stage does not support chunked cross entropy.")
+        kwargs = {"fused_linear_cross_entropy": False}
+    else:
+        kwargs = {}
+
+    apply_liger_kernel(**kwargs)
     logger.info("Liger kernel has been applied to the model.")
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index e4bb7ac1..06d41af5 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -27,7 +27,6 @@ from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
-from .model_utils.liger_kernel import configure_liger_kernel
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
@@ -93,7 +92,6 @@ def patch_config(
 
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
-    configure_liger_kernel(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
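Note on the copyfrom rewrite above: instead of round-tripping through asdict() and hard-coding which non-init attributes to carry over, the new version splits the dataclass fields by their init flag and restores the non-init ones with setattr after construction. A minimal, self-contained sketch of the same pattern follows; the toy Config class and its fields are illustrative only, not the real ModelArguments.

# Sketch of the copyfrom pattern on a hypothetical dataclass (not LLaMA-Factory code).
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Config:
    model_name: str = "demo"
    learning_rate: float = 1e-4
    # Derived later at runtime, so it is not a constructor argument.
    compute_dtype: Optional[str] = field(default=None, init=False)

    @classmethod
    def copyfrom(cls, source: "Config", **kwargs) -> "Config":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            # init=False fields cannot be passed to cls(...), so carry them over separately.
            if attr.init:
                init_args[attr.name] = getattr(source, attr.name)
            else:
                lazy_args[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)  # explicit overrides win
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)

        return result


old = Config(learning_rate=5e-5)
old.compute_dtype = "bfloat16"
new = Config.copyfrom(old, model_name="ref-model")
assert new.compute_dtype == "bfloat16" and new.learning_rate == 5e-5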
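The fused_linear_cross_entropy gate in liger_kernel.py relies on feature detection rather than a version pin: the kwarg is only passed when the installed liger-kernel patch function actually declares it in its signature. A hedged sketch of that inspect-based pattern follows; call_with_supported_kwargs and maybe_apply are stand-in names for illustration, not part of either library.

# Sketch of signature-based feature detection, assuming nothing about the installed library version.
import inspect
from typing import Any, Callable, Dict


def call_with_supported_kwargs(fn: Callable[..., Any], **kwargs: Any) -> Any:
    # Keep only the keyword arguments the target function actually accepts.
    accepted = inspect.signature(fn).parameters
    filtered: Dict[str, Any] = {key: value for key, value in kwargs.items() if key in accepted}
    return fn(**filtered)


def maybe_apply(fused_linear_cross_entropy: bool = True) -> bool:
    # Stand-in for an apply_liger_kernel_to_* patch function.
    return fused_linear_cross_entropy


# An older release without the flag would simply be called with no kwargs at all.
print(call_with_supported_kwargs(maybe_apply, fused_linear_cross_entropy=False))  # -> False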