Former-commit-id: 38e164fe4aaea6f0baf121a720291ca42643ba8c
This commit is contained in:
hiyouga 2024-04-24 05:21:18 +08:00
parent ad24a2a0c9
commit 94c8219575

View File

@ -18,7 +18,7 @@ def _get_unsloth_kwargs(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"model_name": model_name_or_path, "model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length, "max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype, "dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4, "load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
@ -34,7 +34,7 @@ def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments" config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]: ) -> Optional["PreTrainedModel"]:
r""" r"""
Optionally loads pretrained model with unsloth. Optionally loads pretrained model with unsloth. Used in training.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
@ -53,7 +53,7 @@ def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""
Gets the peft model for the pretrained model with unsloth. Gets the peft model for the pretrained model with unsloth. Used in training.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
@ -69,12 +69,15 @@ def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""
Loads peft model with unsloth. Loads peft model with unsloth. Used in both training and inference.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try: try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError: except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))