Former-commit-id: 73ff9c834b069bf8b1bde75cc4daf996746050fa
This commit is contained in:
hiyouga 2024-04-24 05:21:18 +08:00
parent 612ba26c4c
commit 7d89abb1fd

View File

@ -18,7 +18,7 @@ def _get_unsloth_kwargs(
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"model_name": model_name_or_path, "model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length, "max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype, "dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4, "load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
@ -34,7 +34,7 @@ def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments" config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]: ) -> Optional["PreTrainedModel"]:
r""" r"""
Optionally loads pretrained model with unsloth. Optionally loads pretrained model with unsloth. Used in training.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
@ -53,7 +53,7 @@ def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""
Gets the peft model for the pretrained model with unsloth. Gets the peft model for the pretrained model with unsloth. Used in training.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
@ -69,12 +69,15 @@ def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel": ) -> "PreTrainedModel":
r""" r"""
Loads peft model with unsloth. Loads peft model with unsloth. Used in both training and inference.
""" """
from unsloth import FastLanguageModel from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try: try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError: except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))