Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-06 05:32:50 +08:00)

Commit c0afc4074f — support unsloth generate
Parent: 8465e54d38
Former-commit-id: b1deb0a0b920645884e58f8206b1842c144c1c52
src/llmtuner/model/adapter.py

@@ -7,10 +7,11 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
+from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model


 if TYPE_CHECKING:
-    from transformers.modeling_utils import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel

     from ..hparams import FinetuningArguments, ModelArguments

@@ -19,7 +20,11 @@ logger = get_logger(__name__)


 def init_adapter(
-    model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool
+    config: "PretrainedConfig",
+    model: "PreTrainedModel",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
 ) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
@@ -106,6 +111,10 @@ def init_adapter(
                 assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
                 is_mergeable = False

+            if model_args.use_unsloth:
+                assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
+                is_mergeable = False
+
             if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
                 adapter_to_merge = model_args.adapter_name_or_path[:-1]
                 adapter_to_resume = model_args.adapter_name_or_path[-1]
@@ -122,9 +131,15 @@ def init_adapter(
                 logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))

             if adapter_to_resume is not None:  # resume lora training
-                model = PeftModel.from_pretrained(
-                    model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder
-                )
+                if model_args.use_unsloth:
+                    model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+                else:
+                    model = PeftModel.from_pretrained(
+                        model,
+                        adapter_to_resume,
+                        is_trainable=is_trainable,
+                        offload_folder=model_args.offload_folder,
+                    )

         if is_trainable and adapter_to_resume is None:  # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -152,14 +167,8 @@ def init_adapter(
             }

             if model_args.use_unsloth:
-                from unsloth import FastLanguageModel  # type: ignore
-
-                unsloth_peft_kwargs = {
-                    "model": model,
-                    "max_seq_length": model_args.model_max_length,
-                    "use_gradient_checkpointing": "unsloth",
-                }
-                model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+                print(model)
+                model = get_unsloth_peft_model(model, model_args, peft_kwargs)
             else:
                 lora_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
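With these changes, init_adapter needs the model config so that the unsloth branch can rebuild the model directly from a saved adapter. A minimal sketch of the new call as it is wired up in loader.py below; config, model, model_args, and finetuning_args are assumed to be constructed by the surrounding loader code:

# Sketch only: the new five-argument form; loader.py now threads config through to adapter setup.
model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
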
src/llmtuner/model/loader.py

@@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Any, Dict
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

-from ..extras.constants import MOD_SUPPORTED_MODELS
 from ..extras.logging import get_logger
-from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
+from ..extras.misc import count_parameters, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 from .utils.misc import load_valuehead_params, register_autoclass
+from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
+from .utils.unsloth import load_unsloth_pretrained_model


 if TYPE_CHECKING:
@@ -83,54 +84,30 @@ def load_model(
     patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)

     model = None
-    if is_trainable and model_args.use_unsloth:
-        from unsloth import FastLanguageModel  # type: ignore
-
-        unsloth_kwargs = {
-            "model_name": model_args.model_name_or_path,
-            "max_seq_length": model_args.model_max_length,
-            "dtype": model_args.compute_dtype,
-            "load_in_4bit": model_args.quantization_bit == 4,
-            "token": model_args.hf_hub_token,
-            "device_map": {"": get_current_device()},
-            "rope_scaling": getattr(config, "rope_scaling", None),
-            "fix_tokenizer": False,
-            "trust_remote_code": True,
-        }
-        try:
-            model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
-        except NotImplementedError:
-            logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
-            model_args.use_unsloth = False
-
-        if model_args.adapter_name_or_path:
-            model_args.adapter_name_or_path = None
-            logger.warning("Unsloth does not support loading adapters.")
-
-    if model is None:
+    lazy_load = False
+    if model_args.use_unsloth:
+        if model_args.adapter_name_or_path is not None:
+            lazy_load = True
+        elif is_trainable:
+            model = load_unsloth_pretrained_model(config, model_args)
+
+    if model is None and not lazy_load:
         init_kwargs["config"] = config
         init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path

         if model_args.mixture_of_depths == "load":
-            from MoD import AutoMoDModelForCausalLM
-
-            model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
+            model = load_mod_pretrained_model(**init_kwargs)
         else:
             model = AutoModelForCausalLM.from_pretrained(**init_kwargs)

         if model_args.mixture_of_depths == "convert":
-            from MoD import apply_mod_to_hf
-
-            if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
-                raise ValueError("Current model is not supported by mixture-of-depth.")
-
-            model = apply_mod_to_hf(model)
-            model = model.to(model_args.compute_dtype)
-
-    patch_model(model, tokenizer, model_args, is_trainable)
-    register_autoclass(config, model, tokenizer)
-
-    model = init_adapter(model, model_args, finetuning_args, is_trainable)
+            model = convert_pretrained_model_to_mod(model, config, model_args)
+
+    if not lazy_load:
+        patch_model(model, tokenizer, model_args, is_trainable)
+        register_autoclass(config, model, tokenizer)
+
+    model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

     if add_valuehead:
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
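The reworked load_model now has three loading paths. A condensed restatement of the branching above, simplified to omit the mixture-of-depths handling and the fallback that occurs when unsloth rejects a model type:

# Simplified sketch of the decision logic above, not the literal code.
lazy_load = model_args.use_unsloth and model_args.adapter_name_or_path is not None
if lazy_load:
    model = None  # defer: init_adapter() loads the adapter via load_unsloth_peft_model()
elif model_args.use_unsloth and is_trainable:
    model = load_unsloth_pretrained_model(config, model_args)
else:
    model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
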
src/llmtuner/model/utils/mod.py (new file, 28 lines)

@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING
+
+from ...extras.constants import MOD_SUPPORTED_MODELS
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedModel
+
+    from ...hparams import ModelArguments
+
+
+def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
+    from MoD import AutoMoDModelForCausalLM
+
+    return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
+
+
+def convert_pretrained_model_to_mod(
+    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
+) -> "PreTrainedModel":
+    from MoD import apply_mod_to_hf
+
+    if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
+        raise ValueError("Current model is not supported by mixture-of-depth.")
+
+    model = apply_mod_to_hf(model)
+    model = model.to(model_args.compute_dtype)
+    return model
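These two helpers simply relocate the mixture-of-depths handling that previously lived inline in load_model. A hypothetical standalone usage sketch; the model path is a placeholder, and config and model_args are assumed to come from the loader:

from llmtuner.model.utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model

# Either load a checkpoint that is already in MoD format...
model = load_mod_pretrained_model(pretrained_model_name_or_path="path/to/mod-model", config=config)
# ...or convert a freshly loaded causal LM into MoD form and cast it to the compute dtype.
model = convert_pretrained_model_to_mod(model, config, model_args)
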
src/llmtuner/model/utils/unsloth.py (new file, 85 lines)

@@ -0,0 +1,85 @@
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from ...extras.logging import get_logger
+from ...extras.misc import get_current_device
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedModel
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def _get_unsloth_kwargs(
+    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+) -> Dict[str, Any]:
+    return {
+        "model_name": model_name_or_path,
+        "max_seq_length": model_args.model_max_length,
+        "dtype": model_args.compute_dtype,
+        "load_in_4bit": model_args.quantization_bit == 4,
+        "token": model_args.hf_hub_token,
+        "device_map": {"": get_current_device()},
+        "rope_scaling": getattr(config, "rope_scaling", None),
+        "fix_tokenizer": False,
+        "trust_remote_code": True,
+        "use_gradient_checkpointing": "unsloth",
+    }
+
+
+def load_unsloth_pretrained_model(
+    config: "PretrainedConfig", model_args: "ModelArguments"
+) -> Optional["PreTrainedModel"]:
+    r"""
+    Optionally loads pretrained model with unsloth.
+    """
+    from unsloth import FastLanguageModel
+
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    try:
+        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+    except NotImplementedError:
+        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+        model = None
+        model_args.use_unsloth = False
+
+    return model
+
+
+def get_unsloth_peft_model(
+    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
+) -> "PreTrainedModel":
+    r"""
+    Gets the peft model for the pretrained model with unsloth.
+    """
+    from unsloth import FastLanguageModel
+
+    unsloth_peft_kwargs = {
+        "model": model,
+        "max_seq_length": model_args.model_max_length,
+        "use_gradient_checkpointing": "unsloth",
+    }
+    return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+
+
+def load_unsloth_peft_model(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> "PreTrainedModel":
+    r"""
+    Loads peft model with unsloth.
+    """
+    from unsloth import FastLanguageModel
+
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path, model_args)
+    try:
+        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+    except NotImplementedError:
+        raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+
+    if not is_trainable:
+        FastLanguageModel.for_inference(model)
+
+    return model
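This new module carries the commit's headline feature: unsloth-accelerated generation from a finetuned LoRA adapter. A minimal sketch of the inference path, assuming config and model_args have already been prepared by load_model; only names defined in this file or the hunks above are used:

from llmtuner.model.utils.unsloth import load_unsloth_peft_model

# model_args.adapter_name_or_path points at the saved LoRA adapter; is_trainable=False
# makes the helper call FastLanguageModel.for_inference(model) before returning it.
model = load_unsloth_peft_model(config, model_args, is_trainable=False)
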
src/llmtuner/train/utils.py

@@ -61,6 +61,9 @@ def create_modelcard_and_push(
     if data_args.dataset is not None:
         kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")]

+    if model_args.use_unsloth:
+        kwargs["tags"] = kwargs["tags"] + ["unsloth"]
+
     if not training_args.do_train:
         pass
     elif training_args.push_to_hub:
src/llmtuner/webui/components/train.py

@@ -138,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
             lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1)
-            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
             loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01)
             create_new_adapter = gr.Checkbox()
