From c0afc4074f794261711ad81f406cfd9b4d6b922c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 04:46:53 +0800 Subject: [PATCH] support unsloth generate Former-commit-id: b1deb0a0b920645884e58f8206b1842c144c1c52 --- src/llmtuner/model/adapter.py | 35 +++++++---- src/llmtuner/model/loader.py | 55 +++++------------ src/llmtuner/model/utils/mod.py | 28 +++++++++ src/llmtuner/model/utils/unsloth.py | 85 ++++++++++++++++++++++++++ src/llmtuner/train/utils.py | 3 + src/llmtuner/webui/components/train.py | 2 +- 6 files changed, 155 insertions(+), 53 deletions(-) create mode 100644 src/llmtuner/model/utils/mod.py create mode 100644 src/llmtuner/model/utils/unsloth.py diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index efc63cde..d8d8eaf0 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -7,10 +7,11 @@ from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger from .utils.misc import find_all_linear_modules, find_expanded_modules from .utils.quantization import QuantizationMethod +from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model if TYPE_CHECKING: - from transformers.modeling_utils import PreTrainedModel + from transformers import PretrainedConfig, PreTrainedModel from ..hparams import FinetuningArguments, ModelArguments @@ -19,7 +20,11 @@ logger = get_logger(__name__) def init_adapter( - model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, ) -> "PreTrainedModel": r""" Initializes the adapters. @@ -106,6 +111,10 @@ def init_adapter( assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." is_mergeable = False + if model_args.use_unsloth: + assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
+ is_mergeable = False + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): adapter_to_merge = model_args.adapter_name_or_path[:-1] adapter_to_resume = model_args.adapter_name_or_path[-1] @@ -122,9 +131,15 @@ def init_adapter( logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) if adapter_to_resume is not None: # resume lora training - model = PeftModel.from_pretrained( - model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder - ) + if model_args.use_unsloth: + model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) + else: + model = PeftModel.from_pretrained( + model, + adapter_to_resume, + is_trainable=is_trainable, + offload_folder=model_args.offload_folder, + ) if is_trainable and adapter_to_resume is None: # create new lora weights while training if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": @@ -152,14 +167,8 @@ def init_adapter( } if model_args.use_unsloth: - from unsloth import FastLanguageModel # type: ignore - - unsloth_peft_kwargs = { - "model": model, - "max_seq_length": model_args.model_max_length, - "use_gradient_checkpointing": "unsloth", - } - model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + print(model) + model = get_unsloth_peft_model(model, model_args, peft_kwargs) else: lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index b8558542..06405219 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Any, Dict from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from trl import AutoModelForCausalLMWithValueHead -from ..extras.constants import MOD_SUPPORTED_MODELS from ..extras.logging import get_logger -from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms +from ..extras.misc import count_parameters, try_download_model_from_ms from .adapter import init_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model from .utils.misc import load_valuehead_params, register_autoclass +from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model +from .utils.unsloth import load_unsloth_pretrained_model if TYPE_CHECKING: @@ -83,54 +84,30 @@ def load_model( patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) model = None - if is_trainable and model_args.use_unsloth: - from unsloth import FastLanguageModel # type: ignore + lazy_load = False + if model_args.use_unsloth: + if model_args.adapter_name_or_path is not None: + lazy_load = True + elif is_trainable: + model = load_unsloth_pretrained_model(config, model_args) - unsloth_kwargs = { - "model_name": model_args.model_name_or_path, - "max_seq_length": model_args.model_max_length, - "dtype": model_args.compute_dtype, - "load_in_4bit": model_args.quantization_bit == 4, - "token": model_args.hf_hub_token, - "device_map": {"": get_current_device()}, - "rope_scaling": getattr(config, "rope_scaling", None), - "fix_tokenizer": False, - "trust_remote_code": True, - } - try: - model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) - except NotImplementedError: - logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) - model_args.use_unsloth = False - - if model_args.adapter_name_or_path: - model_args.adapter_name_or_path = None - 
logger.warning("Unsloth does not support loading adapters.") - - if model is None: + if model is None and not lazy_load: init_kwargs["config"] = config init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path if model_args.mixture_of_depths == "load": - from MoD import AutoMoDModelForCausalLM - - model = AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + model = load_mod_pretrained_model(**init_kwargs) else: model = AutoModelForCausalLM.from_pretrained(**init_kwargs) if model_args.mixture_of_depths == "convert": - from MoD import apply_mod_to_hf + model = convert_pretrained_model_to_mod(model, config, model_args) - if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: - raise ValueError("Current model is not supported by mixture-of-depth.") + if not lazy_load: + patch_model(model, tokenizer, model_args, is_trainable) + register_autoclass(config, model, tokenizer) - model = apply_mod_to_hf(model) - model = model.to(model_args.compute_dtype) - - patch_model(model, tokenizer, model_args, is_trainable) - register_autoclass(config, model, tokenizer) - - model = init_adapter(model, model_args, finetuning_args, is_trainable) + model = init_adapter(config, model, model_args, finetuning_args, is_trainable) if add_valuehead: model = AutoModelForCausalLMWithValueHead.from_pretrained(model) diff --git a/src/llmtuner/model/utils/mod.py b/src/llmtuner/model/utils/mod.py new file mode 100644 index 00000000..5708a1a8 --- /dev/null +++ b/src/llmtuner/model/utils/mod.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +from ...extras.constants import MOD_SUPPORTED_MODELS + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": + from MoD import AutoMoDModelForCausalLM + + return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + + +def convert_pretrained_model_to_mod( + model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" +) -> "PreTrainedModel": + from MoD import apply_mod_to_hf + + if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: + raise ValueError("Current model is not supported by mixture-of-depth.") + + model = apply_mod_to_hf(model) + model = model.to(model_args.compute_dtype) + return model diff --git a/src/llmtuner/model/utils/unsloth.py b/src/llmtuner/model/utils/unsloth.py new file mode 100644 index 00000000..6c5f506f --- /dev/null +++ b/src/llmtuner/model/utils/unsloth.py @@ -0,0 +1,85 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _get_unsloth_kwargs( + config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" +) -> Dict[str, Any]: + return { + "model_name": model_name_or_path, + "max_seq_length": model_args.model_max_length, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + "use_gradient_checkpointing": "unsloth", + } + + +def load_unsloth_pretrained_model( + config: "PretrainedConfig", model_args: "ModelArguments" +) -> 
Optional["PreTrainedModel"]: + r""" + Optionally loads pretrained model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model = None + model_args.use_unsloth = False + + return model + + +def get_unsloth_peft_model( + model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] +) -> "PreTrainedModel": + r""" + Gets the peft model for the pretrained model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": model_args.model_max_length, + "use_gradient_checkpointing": "unsloth", + } + return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + + +def load_unsloth_peft_model( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> "PreTrainedModel": + r""" + Loads peft model with unsloth. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + + if not is_trainable: + FastLanguageModel.for_inference(model) + + return model diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index fa9e36e5..27dc8eb3 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -61,6 +61,9 @@ def create_modelcard_and_push( if data_args.dataset is not None: kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")] + if model_args.use_unsloth: + kwargs["tags"] = kwargs["tags"] + ["unsloth"] + if not training_args.do_train: pass elif training_args.push_to_hub: diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 0f425bc9..7dc324af 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -138,7 +138,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1) lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1) - lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01) + lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01) loraplus_lr_ratio = gr.Slider(value=0, minimum=0, maximum=64, step=0.01) create_new_adapter = gr.Checkbox()