From 7d89abb1fd34a716c792eb61ea416cdd3fb8b060 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 05:21:18 +0800 Subject: [PATCH] fix bug Former-commit-id: 73ff9c834b069bf8b1bde75cc4daf996746050fa --- src/llmtuner/model/utils/unsloth.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/model/utils/unsloth.py b/src/llmtuner/model/utils/unsloth.py index 974b41c0..8a16409d 100644 --- a/src/llmtuner/model/utils/unsloth.py +++ b/src/llmtuner/model/utils/unsloth.py @@ -18,7 +18,7 @@ def _get_unsloth_kwargs( ) -> Dict[str, Any]: return { "model_name": model_name_or_path, - "max_seq_length": model_args.model_max_length, + "max_seq_length": model_args.model_max_length or 4096, "dtype": model_args.compute_dtype, "load_in_4bit": model_args.quantization_bit == 4, "token": model_args.hf_hub_token, @@ -34,7 +34,7 @@ def load_unsloth_pretrained_model( config: "PretrainedConfig", model_args: "ModelArguments" ) -> Optional["PreTrainedModel"]: r""" - Optionally loads pretrained model with unsloth. + Optionally loads pretrained model with unsloth. Used in training. """ from unsloth import FastLanguageModel @@ -53,7 +53,7 @@ def get_unsloth_peft_model( model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] ) -> "PreTrainedModel": r""" - Gets the peft model for the pretrained model with unsloth. + Gets the peft model for the pretrained model with unsloth. Used in training. """ from unsloth import FastLanguageModel @@ -69,12 +69,15 @@ def load_unsloth_peft_model( config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool ) -> "PreTrainedModel": r""" - Loads peft model with unsloth. + Loads peft model with unsloth. Used in both training and inference. """ from unsloth import FastLanguageModel unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) try: + if not is_trainable: + unsloth_kwargs["use_gradient_checkpointing"] = False + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) except NotImplementedError: raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))