diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index ec303655..00a4c72c 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers.integrations import is_deepspeed_zero3_enabled from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode -from transformers.utils import is_torch_bf16_gpu_available +from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available from transformers.utils.versions import require_version from ..extras.constants import CHECKPOINT_NAMES @@ -215,14 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ): raise ValueError("Please specify dataset for evaluation.") - if training_args.predict_with_generate and data_args.eval_dataset is None: - raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.") + if training_args.predict_with_generate: + if is_deepspeed_zero3_enabled(): + raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.") - if training_args.predict_with_generate and finetuning_args.compute_accuracy: - raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.") + if data_args.eval_dataset is None: + raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.") - if training_args.predict_with_generate and is_deepspeed_zero3_enabled(): - raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.") + if finetuning_args.compute_accuracy: + raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.") if training_args.do_train and model_args.quantization_device_map == "auto": raise ValueError("Cannot use device map for quantized models in training.") @@ -231,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.") if finetuning_args.pure_bf16: - if not is_torch_bf16_gpu_available(): + if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())): raise ValueError("This device does not support `pure_bf16`.") if is_deepspeed_zero3_enabled(): diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index 61de0be0..e40169ad 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -32,12 +32,16 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen if getattr(config, "model_type", None) == "gemma": from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel + elif getattr(config, "model_type", None) == "gemma2": + from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel elif getattr(config, "model_type", None) == "llama": from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel elif getattr(config, "model_type", None) == "mistral": from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel elif getattr(config, "model_type", None) == "mixtral": from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel + elif getattr(config, "model_type", None) == "phi3": + from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel elif getattr(config, "model_type", None) == "qwen2": from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel else: