From 8465e54d3897ed5c90ba71123d5c628330905faa Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 24 Apr 2024 03:02:23 +0800 Subject: [PATCH] refactor patcher Former-commit-id: aa2b79eb23c60825e6601b0b8cc6b59e3f566b2d --- src/llmtuner/extras/constants.py | 2 + src/llmtuner/model/__init__.py | 2 +- src/llmtuner/model/adapter.py | 3 +- src/llmtuner/model/loader.py | 2 +- src/llmtuner/model/patcher.py | 325 +----------------- .../patches => model/utils}/__init__.py | 0 src/llmtuner/model/utils/attention.py | 55 +++ src/llmtuner/model/utils/checkpointing.py | 94 +++++ src/llmtuner/model/utils/embedding.py | 56 +++ .../utils/longlora.py} | 151 +++++++- .../model/{utils.py => utils/misc.py} | 74 +--- src/llmtuner/model/utils/moe.py | 39 +++ src/llmtuner/model/utils/quantization.py | 146 ++++++++ src/llmtuner/model/utils/rope.py | 43 +++ 14 files changed, 598 insertions(+), 394 deletions(-) rename src/llmtuner/{extras/patches => model/utils}/__init__.py (100%) create mode 100644 src/llmtuner/model/utils/attention.py create mode 100644 src/llmtuner/model/utils/checkpointing.py create mode 100644 src/llmtuner/model/utils/embedding.py rename src/llmtuner/{extras/patches/llama_patch.py => model/utils/longlora.py} (58%) rename src/llmtuner/model/{utils.py => utils/misc.py} (61%) create mode 100644 src/llmtuner/model/utils/moe.py create mode 100644 src/llmtuner/model/utils/quantization.py create mode 100644 src/llmtuner/model/utils/rope.py diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 38d715f5..0a29f971 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -47,6 +47,8 @@ TRAINING_STAGES = { STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] +SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] + V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index e0b1c9cd..1824f084 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,5 @@ from .loader import load_config, load_model, load_tokenizer -from .utils import find_all_linear_modules, load_valuehead_params +from .utils.misc import find_all_linear_modules, load_valuehead_params __all__ = [ diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index f73666d5..efc63cde 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -5,7 +5,8 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger -from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules +from .utils.misc import find_all_linear_modules, find_expanded_modules +from .utils.quantization import QuantizationMethod if TYPE_CHECKING: diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 57f5a763..b8558542 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -8,7 +8,7 @@ from ..extras.logging import get_logger from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms from .adapter import init_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model -from .utils import load_valuehead_params, register_autoclass +from .utils.misc import load_valuehead_params, register_autoclass if TYPE_CHECKING: diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 
6c79992a..c0166a8a 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -1,23 +1,20 @@ -import math -import os -import random -from contextlib import nullcontext from types import MethodType -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict import torch -from datasets import load_dataset from peft import PeftModel -from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations import is_deepspeed_zero3_enabled -from transformers.utils.versions import require_version -from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES from ..extras.logging import get_logger -from ..extras.misc import get_current_device, infer_optim_dtype -from ..extras.packages import is_flash_attn2_available, is_sdpa_available -from ..extras.patches.llama_patch import apply_llama_patch -from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable +from ..extras.misc import infer_optim_dtype +from .utils.attention import configure_attn_implementation, print_attn_implementation +from .utils.checkpointing import prepare_model_for_training +from .utils.embedding import resize_embedding_layer +from .utils.longlora import configure_longlora +from .utils.moe import add_z3_leaf_module +from .utils.quantization import configure_quantization +from .utils.rope import configure_rope if TYPE_CHECKING: @@ -28,282 +25,6 @@ if TYPE_CHECKING: logger = get_logger(__name__) -SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] - - -def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: - r""" - Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 - TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 - """ - if os.path.isfile(model_args.export_quantization_dataset): - data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) - data_files = model_args.export_quantization_dataset - else: - data_path = model_args.export_quantization_dataset - data_files = None - - dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) - maxlen = model_args.export_quantization_maxlen - - samples = [] - for _ in range(model_args.export_quantization_nsamples): - while True: - sample_idx = random.randint(0, len(dataset) - 1) - sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") - if sample["input_ids"].size(1) >= maxlen: - break # TODO: fix large maxlen - - word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) - input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] - samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) - - return samples - - -def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - if model_args.flash_attn == "auto": - return - - elif model_args.flash_attn == "off": - requested_attn_implementation = "eager" - - elif model_args.flash_attn == "sdpa": - if not is_sdpa_available(): - logger.warning("Torch>=2.1.1 is required for SDPA attention.") - return - - requested_attn_implementation = "sdpa" - elif model_args.flash_attn == "fa2": - if not is_flash_attn2_available(): - logger.warning("FlashAttention-2 is not installed.") - return - - requested_attn_implementation = 
"flash_attention_2" - else: - raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) - - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - setattr(config, "attn_implementation", requested_attn_implementation) - else: - setattr(config, "_attn_implementation", requested_attn_implementation) - - -def _print_attn_implementation(config: "PretrainedConfig") -> None: - if getattr(config, "model_type", None) == "internlm2": # special case for custom models - attn_implementation = getattr(config, "attn_implementation", None) - else: - attn_implementation = getattr(config, "_attn_implementation", None) - - if attn_implementation == "flash_attention_2": - logger.info("Using FlashAttention-2 for faster training and inference.") - elif attn_implementation == "sdpa": - logger.info("Using torch SDPA for faster training and inference.") - else: - logger.info("Using vanilla Attention implementation.") - - -def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: - if model_args.rope_scaling is None: - return - - if not hasattr(config, "rope_scaling"): - logger.warning("Current model does not support RoPE scaling.") - return - - if is_trainable: - if model_args.rope_scaling == "dynamic": - logger.warning( - "Dynamic NTK scaling may not work well with fine-tuning. " - "See: https://github.com/huggingface/transformers/pull/24653" - ) - - current_max_length = getattr(config, "max_position_embeddings", None) - if current_max_length and model_args.model_max_length > current_max_length: - scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) - else: - logger.warning("Input length is smaller than max length. Consider increase input length.") - scaling_factor = 1.0 - else: - scaling_factor = 2.0 - - setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) - logger.info( - "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) - ) - - -def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: - if not is_trainable or not model_args.shift_attn: - return - - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: - setattr(config, "group_size_ratio", 0.25) - apply_llama_patch() - logger.info("Using shift short attention with group_size_ratio=1/4.") - else: - logger.warning("Current model does not support shift short attention.") - - -def _configure_quantization( - config: "PretrainedConfig", - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - init_kwargs: Dict[str, Any], -) -> None: - r""" - Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) - """ - if getattr(config, "quantization_config", None): # ptq - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") - - if model_args.quantization_device_map != "auto": - init_kwargs["device_map"] = {"": get_current_device()} - - quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) - quant_method = quantization_config.get("quant_method", "") - - if quant_method == QuantizationMethod.GPTQ: - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - quantization_config.pop("disable_exllama", None) # remove deprecated args - quantization_config["use_exllama"] = False # disable exllama - - if quant_method == 
QuantizationMethod.AWQ: - require_version("autoawq", "To fix: pip install autoawq") - - if quant_method == QuantizationMethod.AQLM: - require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") - require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") - quantization_config["bits"] = 2 - - quant_bits = quantization_config.get("bits", "?") - logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) - - elif model_args.export_quantization_bit is not None: # auto-gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - from accelerate.utils import get_max_memory - - if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") - - init_kwargs["quantization_config"] = GPTQConfig( - bits=model_args.export_quantization_bit, - tokenizer=tokenizer, - dataset=_get_quantization_dataset(tokenizer, model_args), - ) - init_kwargs["device_map"] = "auto" - init_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) - - elif model_args.quantization_bit is not None: # bnb - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) - - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - init_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type, - bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora - ) - - if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": - if model_args.quantization_bit != 4: - raise ValueError("Only 4-bit quantized model can use auto device map.") - - require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") - require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") - require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") - else: - init_kwargs["device_map"] = {"": get_current_device()} - - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) - - -def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): - embedding_dim = embed_weight.size(1) - avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) - noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) - noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) - embed_weight[-num_new_tokens:] = avg_weight + noise_weight - - -def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: - r""" - Resize token embeddings. 
- """ - if is_deepspeed_zero3_enabled(): - import deepspeed # type: ignore - - params = [model.get_input_embeddings().weight] - if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: - params.append(model.get_output_embeddings().weight) - - context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) - else: - context_maybe_zero3 = nullcontext() - - with context_maybe_zero3: - current_embedding_size = model.get_input_embeddings().weight.size(0) - - if len(tokenizer) > current_embedding_size: - if not isinstance(model.get_output_embeddings(), torch.nn.Linear): - logger.warning("Current model does not support resizing token embeddings.") - return - - model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) - with context_maybe_zero3: - new_embedding_size = model.get_input_embeddings().weight.size(0) - num_new_tokens = new_embedding_size - current_embedding_size - _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) - _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) - - logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) - - -def _fp32_forward_post_hook( - module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" -) -> "torch.Tensor": - return output.to(torch.float32) - - -def _prepare_model_for_training( - model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" -) -> None: - r""" - Includes: - (1) cast the layernorm in fp32 - (2) make output embedding layer require grads - (3) add the upcasting of the lm_head in fp32 - Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 - """ - if model_args.upcast_layernorm: - logger.info("Upcasting layernorm weights in float32.") - for name, param in model.named_parameters(): - if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): - param.data = param.data.to(torch.float32) - - if not model_args.disable_gradient_checkpointing: - if not getattr(model, "supports_gradient_checkpointing", False): - logger.warning("Current model does not support gradient checkpointing.") - else: - # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) - # According to: https://github.com/huggingface/transformers/issues/28339 - model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) - setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled - logger.info("Gradient checkpointing enabled.") - - if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: - logger.info("Upcasting lm_head outputs in float32.") - output_layer = getattr(model, output_layer_name) - if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: - output_layer.register_forward_hook(_fp32_forward_post_hook) def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: @@ -321,10 +42,10 @@ def patch_config( if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) - _configure_attn_implementation(config, model_args) - _configure_rope(config, model_args, is_trainable) - _configure_longlora(config, model_args, is_trainable) - _configure_quantization(config, tokenizer, 
model_args, init_kwargs) + configure_attn_implementation(config, model_args) + configure_rope(config, model_args, is_trainable) + configure_longlora(config, model_args, is_trainable) + configure_quantization(config, tokenizer, model_args, init_kwargs) if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) @@ -377,22 +98,14 @@ def patch_model( setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) if model_args.resize_vocab: - _resize_embedding_layer(model, tokenizer) + resize_embedding_layer(model, tokenizer) if is_trainable: - _prepare_model_for_training(model, model_args) + prepare_model_for_training(model, model_args) + add_z3_leaf_module(model) - if getattr(model.config, "model_type", None) == "mixtral": - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - - add_z3_leaf_module(model, MixtralSparseMoeBlock) - - if getattr(model.config, "model_type", None) == "qwen2moe": - from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - - add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock) - - _print_attn_implementation(model.config) + if not model_args.use_unsloth: + print_attn_implementation(model.config) try: model.add_model_tags(["llama-factory"]) diff --git a/src/llmtuner/extras/patches/__init__.py b/src/llmtuner/model/utils/__init__.py similarity index 100% rename from src/llmtuner/extras/patches/__init__.py rename to src/llmtuner/model/utils/__init__.py diff --git a/src/llmtuner/model/utils/attention.py b/src/llmtuner/model/utils/attention.py new file mode 100644 index 00000000..f4686489 --- /dev/null +++ b/src/llmtuner/model/utils/attention.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger +from ...extras.packages import is_flash_attn2_available, is_sdpa_available + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None: + if model_args.flash_attn == "auto": + return + + elif model_args.flash_attn == "off": + requested_attn_implementation = "eager" + + elif model_args.flash_attn == "sdpa": + if not is_sdpa_available(): + logger.warning("Torch>=2.1.1 is required for SDPA attention.") + return + + requested_attn_implementation = "sdpa" + elif model_args.flash_attn == "fa2": + if not is_flash_attn2_available(): + logger.warning("FlashAttention-2 is not installed.") + return + + requested_attn_implementation = "flash_attention_2" + else: + raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) + + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", requested_attn_implementation) + else: + setattr(config, "_attn_implementation", requested_attn_implementation) + + +def print_attn_implementation(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation == "flash_attention_2": + logger.info("Using FlashAttention-2 for faster training and inference.") + elif attn_implementation == "sdpa": + logger.info("Using torch SDPA for faster training and inference.") + else: + logger.info("Using vanilla Attention implementation.") diff 
--git a/src/llmtuner/model/utils/checkpointing.py b/src/llmtuner/model/utils/checkpointing.py new file mode 100644 index 00000000..e0657be8 --- /dev/null +++ b/src/llmtuner/model/utils/checkpointing.py @@ -0,0 +1,94 @@ +import inspect +from functools import partial +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import torch + +from ...extras.constants import LAYERNORM_NAMES +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _gradient_checkpointing_enable( + self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None +) -> None: + r""" + Activates gradient checkpointing for the current model. + + Modification of the original method to enable gradient checkpointing for block-wise optimizer. + """ + from torch.utils.checkpoint import checkpoint + + if not self.supports_gradient_checkpointing: + raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) + + def custom_gradient_checkpointing_func(func, *args, **kwargs): + module: "torch.nn.Module" = func.__self__ + + if any(param.requires_grad for param in module.parameters()): + for arg in args: + if torch.is_tensor(arg) and torch.is_floating_point(arg): + arg.requires_grad_(True) + + return gradient_checkpointing_func(func, *args, **kwargs) + + if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format + self.apply(partial(self._set_gradient_checkpointing, value=True)) + self.enable_input_require_grads() + logger.warning("You are using the old GC format, some features (e.g. 
BAdam) will be invalid.") + else: # have already enabled input require gradients + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) + + +def _fp32_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" +) -> "torch.Tensor": + return output.to(torch.float32) + + +def prepare_model_for_training( + model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head" +) -> None: + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32 + Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 + """ + if model_args.upcast_layernorm: + logger.info("Upcasting layernorm weights in float32.") + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): + param.data = param.data.to(torch.float32) + + if not model_args.disable_gradient_checkpointing: + if not getattr(model, "supports_gradient_checkpointing", False): + logger.warning("Current model does not support gradient checkpointing.") + else: + # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) + # According to: https://github.com/huggingface/transformers/issues/28339 + model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output: + logger.info("Upcasting lm_head outputs in float32.") + output_layer = getattr(model, output_layer_name) + if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/src/llmtuner/model/utils/embedding.py b/src/llmtuner/model/utils/embedding.py new file mode 100644 index 00000000..7759fc0f --- /dev/null +++ b/src/llmtuner/model/utils/embedding.py @@ -0,0 +1,56 @@ +import math +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None: + embedding_dim = embed_weight.size(1) + avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight + + +def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: + r""" + Resize token embeddings. 
+ """ + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + logger.warning("Current model does not support resizing token embeddings.") + return + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) + + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/model/utils/longlora.py similarity index 58% rename from src/llmtuner/extras/patches/llama_patch.py rename to src/llmtuner/model/utils/longlora.py index 6a90c41a..c3740a73 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/model/utils/longlora.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.nn as nn @@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import ( Cache, LlamaAttention, LlamaFlashAttention2, + LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv, ) from transformers.utils import logging from transformers.utils.versions import require_version +from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + logger = logging.get_logger(__name__) # Modified from: -# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py -def llama_torch_attn_forward( +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_attention_forward( self: "LlamaAttention", hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -39,10 +48,11 @@ def llama_torch_attn_forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -69,8 +79,9 @@ def llama_torch_attn_forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + if 
attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -97,8 +108,8 @@ def llama_torch_attn_forward( # Modified from: -# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py -def llama_flash_attn_forward( +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_flash_attention_2_forward( self: "LlamaFlashAttention2", hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -117,7 +128,6 @@ def llama_flash_attn_forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -134,9 +144,10 @@ def llama_flash_attn_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) - value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 @@ -192,7 +203,115 @@ def llama_flash_attn_forward( return attn_output, attn_weights, past_key_value -def apply_llama_patch() -> None: - require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3") - LlamaAttention.forward = llama_torch_attn_forward - LlamaFlashAttention2.forward = llama_flash_attn_forward +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_sdpa_attention_forward( + self: "LlamaSdpaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once("SDPA does not support `output_attentions=True`. 
Falling back to the vanilla attention") + return llama_attention_forward( + self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: torch.Tensor) -> torch.Tensor: + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :groupsz] + + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ) + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def _apply_llama_patch() -> None: + require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0") + LlamaAttention.forward = llama_attention_forward + LlamaFlashAttention2.forward = llama_flash_attention_2_forward + LlamaSdpaAttention.forward = llama_sdpa_attention_forward + + +def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> 
None: + if not is_trainable or not model_args.shift_attn: + return + + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + _apply_llama_patch() + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils/misc.py similarity index 61% rename from src/llmtuner/model/utils.py rename to src/llmtuner/model/utils/misc.py index 51dbca8e..57e772f7 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils/misc.py @@ -1,51 +1,23 @@ -import inspect -from enum import Enum, unique -from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List import torch from transformers import PreTrainedModel -from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import cached_file -from transformers.utils.versions import require_version -from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.logging import get_logger +from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ...extras.logging import get_logger +from .quantization import QuantizationMethod if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer - from ..hparams import ModelArguments + from ...hparams import ModelArguments logger = get_logger(__name__) -@unique -class QuantizationMethod(str, Enum): - r""" - Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. - """ - - BITS_AND_BYTES = "bitsandbytes" - GPTQ = "gptq" - AWQ = "awq" - AQLM = "aqlm" - QUANTO = "quanto" - - -def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None: - r""" - Sets module as a leaf module to skip partitioning in deepspeed zero3. - """ - if is_deepspeed_zero3_enabled(): - require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") - from deepspeed.utils import set_z3_leaf_modules # type: ignore - - set_z3_leaf_modules(model, [module]) - - def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: r""" Finds all available modules to apply lora or galore. @@ -102,42 +74,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n return module_names -def gradient_checkpointing_enable( - self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None -) -> None: - r""" - Activates gradient checkpointing for the current model. - - Modification of the original method to enable gradient checkpointing for block-wise optimizer. 
- """ - from torch.utils.checkpoint import checkpoint - - if not self.supports_gradient_checkpointing: - raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {"use_reentrant": True} - - gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) - - def custom_gradient_checkpointing_func(func, *args, **kwargs): - module: "torch.nn.Module" = func.__self__ - - if any(param.requires_grad for param in module.parameters()): - for arg in args: - if torch.is_tensor(arg) and torch.is_floating_point(arg): - arg.requires_grad_(True) - - return gradient_checkpointing_func(func, *args, **kwargs) - - if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format - self.apply(partial(self._set_gradient_checkpointing, value=True)) - self.enable_input_require_grads() - logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") - else: # have already enabled input require gradients - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) - - def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: r""" Loads value head parameters from Hugging Face Hub or local disk. diff --git a/src/llmtuner/model/utils/moe.py b/src/llmtuner/model/utils/moe.py new file mode 100644 index 00000000..020a8f55 --- /dev/null +++ b/src/llmtuner/model/utils/moe.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +def add_z3_leaf_module(model: "PreTrainedModel") -> None: + r""" + Sets module as a leaf module to skip partitioning in deepspeed zero3. 
+ """ + if not is_deepspeed_zero3_enabled(): + return + + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + if getattr(model.config, "model_type", None) == "mixtral": + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "qwen2moe": + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "jamba": + from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock + + set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "dbrx": + from transformers.models.dbrx.modeling_dbrx import DbrxFFN + + set_z3_leaf_modules(model, [DbrxFFN]) diff --git a/src/llmtuner/model/utils/quantization.py b/src/llmtuner/model/utils/quantization.py new file mode 100644 index 00000000..3cf159c1 --- /dev/null +++ b/src/llmtuner/model/utils/quantization.py @@ -0,0 +1,146 @@ +import os +import random +from enum import Enum, unique +from typing import TYPE_CHECKING, Any, Dict, List + +import torch +from datasets import load_dataset +from transformers import BitsAndBytesConfig, GPTQConfig +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + +from ...extras.constants import FILEEXT2TYPE +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +@unique +class QuantizationMethod(str, Enum): + r""" + Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. 
+ """ + + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + QUANTO = "quanto" + + +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: + r""" + Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 + TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + """ + if os.path.isfile(model_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) + data_files = model_args.export_quantization_dataset + else: + data_path = model_args.export_quantization_dataset + data_files = None + + dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) + maxlen = model_args.export_quantization_maxlen + + samples = [] + for _ in range(model_args.export_quantization_nsamples): + while True: + sample_idx = random.randint(0, len(dataset) - 1) + sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + if sample["input_ids"].size(1) >= maxlen: + break # TODO: fix large maxlen + + word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) + input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] + samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + + return samples + + +def configure_quantization( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], +) -> None: + r""" + Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + """ + if getattr(config, "quantization_config", None): # ptq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.") + + if model_args.quantization_device_map != "auto": + init_kwargs["device_map"] = {"": get_current_device()} + + quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + + if quant_method == QuantizationMethod.GPTQ: + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + quantization_config.pop("disable_exllama", None) # remove deprecated args + quantization_config["use_exllama"] = False # disable exllama + + if quant_method == QuantizationMethod.AWQ: + require_version("autoawq", "To fix: pip install autoawq") + + if quant_method == QuantizationMethod.AQLM: + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + quantization_config["bits"] = 2 + + quant_bits = quantization_config.get("bits", "?") + logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) + + elif model_args.export_quantization_bit is not None: # auto-gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported.") + + init_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + tokenizer=tokenizer, + dataset=_get_quantization_dataset(tokenizer, model_args), + ) + init_kwargs["device_map"] = "auto" + init_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing 
model to {} bit.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # bnb + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora + ) + + if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use auto device map.") + + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} + + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) diff --git a/src/llmtuner/model/utils/rope.py b/src/llmtuner/model/utils/rope.py new file mode 100644 index 00000000..2a4cce7a --- /dev/null +++ b/src/llmtuner/model/utils/rope.py @@ -0,0 +1,43 @@ +import math +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is None: + return + + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + return + + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info( + "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) + )