Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

Commit 7d89abb1fd — fix bug
Parent: 612ba26c4c
Former-commit-id: 73ff9c834b069bf8b1bde75cc4daf996746050fa
@@ -18,7 +18,7 @@ def _get_unsloth_kwargs(
 ) -> Dict[str, Any]:
     return {
         "model_name": model_name_or_path,
-        "max_seq_length": model_args.model_max_length,
+        "max_seq_length": model_args.model_max_length or 4096,
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
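The first hunk relies on Python's short-circuiting `or`: when `model_max_length` is unset (`None`) or `0`, the kwargs now fall back to a 4096-token sequence length instead of passing a falsy value through. A minimal sketch of that semantics; `resolve_max_seq_length` is a hypothetical helper for illustration, not part of the repository:

# resolve_max_seq_length is hypothetical; it only demonstrates the
# `x or 4096` fallback introduced in the hunk above.
def resolve_max_seq_length(model_max_length):
    # `x or 4096` yields 4096 for any falsy value (None, 0), else x itself.
    return model_max_length or 4096

assert resolve_max_seq_length(None) == 4096  # unset -> default length
assert resolve_max_seq_length(0) == 4096     # zero is falsy, also falls back
assert resolve_max_seq_length(8192) == 8192  # explicit value is kept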
@@ -34,7 +34,7 @@ def load_unsloth_pretrained_model(
     config: "PretrainedConfig", model_args: "ModelArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""
-    Optionally loads pretrained model with unsloth.
+    Optionally loads pretrained model with unsloth. Used in training.
     """
     from unsloth import FastLanguageModel
 
@@ -53,7 +53,7 @@ def get_unsloth_peft_model(
     model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
 ) -> "PreTrainedModel":
     r"""
-    Gets the peft model for the pretrained model with unsloth.
+    Gets the peft model for the pretrained model with unsloth. Used in training.
     """
     from unsloth import FastLanguageModel
 
@@ -69,12 +69,15 @@ def load_unsloth_peft_model(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> "PreTrainedModel":
     r"""
-    Loads peft model with unsloth.
+    Loads peft model with unsloth. Used in both training and inference.
    """
     from unsloth import FastLanguageModel
 
     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
     try:
+        if not is_trainable:
+            unsloth_kwargs["use_gradient_checkpointing"] = False
+
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
         raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
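The last hunk changes behavior: when the model is loaded for inference (`is_trainable` is false), gradient checkpointing is switched off before calling `FastLanguageModel.from_pretrained`, since checkpointing only trades compute for memory during backpropagation and has no benefit on the forward-only path. A minimal sketch of the resulting control flow, assuming the unsloth package is installed; the model name and kwargs below are placeholders standing in for `_get_unsloth_kwargs` output, not values from the repository:

# Sketch of the fixed call path; model_name and the other kwargs are
# hypothetical placeholders for what _get_unsloth_kwargs would build.
from unsloth import FastLanguageModel

unsloth_kwargs = {
    "model_name": "unsloth/llama-2-7b",  # hypothetical model id
    "max_seq_length": 4096,
    "load_in_4bit": True,
}

is_trainable = False  # inference path
if not is_trainable:
    # Gradient checkpointing recomputes activations in the backward pass;
    # for inference there is no backward pass, so it is disabled here.
    unsloth_kwargs["use_gradient_checkpointing"] = False

try:
    model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
    raise ValueError("Unsloth does not support this model type.")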