mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-28 16:52:51 +08:00)
refactor patcher
Former-commit-id: aa2b79eb23c60825e6601b0b8cc6b59e3f566b2d
This commit is contained in:
parent
80c8586534
commit
8465e54d38
@@ -47,6 +47,8 @@ TRAINING_STAGES = {
 
 STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
 
+SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"
 
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -1,5 +1,5 @@
 from .loader import load_config, load_model, load_tokenizer
-from .utils import find_all_linear_modules, load_valuehead_params
+from .utils.misc import find_all_linear_modules, load_valuehead_params
 
 
 __all__ = [
@@ -5,7 +5,8 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras.logging import get_logger
-from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
+from .utils.misc import find_all_linear_modules, find_expanded_modules
+from .utils.quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
@@ -8,7 +8,7 @@ from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
-from .utils import load_valuehead_params, register_autoclass
+from .utils.misc import load_valuehead_params, register_autoclass
 
 
 if TYPE_CHECKING:
@@ -1,23 +1,20 @@
-import math
-import os
-import random
-from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict
 
 import torch
-from datasets import load_dataset
 from peft import PeftModel
-from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
 
-from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
 from ..extras.logging import get_logger
-from ..extras.misc import get_current_device, infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available, is_sdpa_available
-from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
+from ..extras.misc import infer_optim_dtype
+from .utils.attention import configure_attn_implementation, print_attn_implementation
+from .utils.checkpointing import prepare_model_for_training
+from .utils.embedding import resize_embedding_layer
+from .utils.longlora import configure_longlora
+from .utils.moe import add_z3_leaf_module
+from .utils.quantization import configure_quantization
+from .utils.rope import configure_rope
 
 
 if TYPE_CHECKING:
@@ -28,282 +25,6 @@ if TYPE_CHECKING:
 
 
 logger = get_logger(__name__)
 
-SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
-
-
-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
-    r"""
-    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
-    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
-    """
-    if os.path.isfile(model_args.export_quantization_dataset):
-        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
-        data_files = model_args.export_quantization_dataset
-    else:
-        data_path = model_args.export_quantization_dataset
-        data_files = None
-
-    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
-    maxlen = model_args.export_quantization_maxlen
-
-    samples = []
-    for _ in range(model_args.export_quantization_nsamples):
-        while True:
-            sample_idx = random.randint(0, len(dataset) - 1)
-            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
-            if sample["input_ids"].size(1) >= maxlen:
-                break  # TODO: fix large maxlen
-
-        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
-        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
-        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
-
-    return samples
-
-
-def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if model_args.flash_attn == "auto":
-        return
-
-    elif model_args.flash_attn == "off":
-        requested_attn_implementation = "eager"
-
-    elif model_args.flash_attn == "sdpa":
-        if not is_sdpa_available():
-            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
-            return
-
-        requested_attn_implementation = "sdpa"
-    elif model_args.flash_attn == "fa2":
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention-2 is not installed.")
-            return
-
-        requested_attn_implementation = "flash_attention_2"
-    else:
-        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
-
-    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        setattr(config, "attn_implementation", requested_attn_implementation)
-    else:
-        setattr(config, "_attn_implementation", requested_attn_implementation)
-
-
-def _print_attn_implementation(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", None)
-    else:
-        attn_implementation = getattr(config, "_attn_implementation", None)
-
-    if attn_implementation == "flash_attention_2":
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-    elif attn_implementation == "sdpa":
-        logger.info("Using torch SDPA for faster training and inference.")
-    else:
-        logger.info("Using vanilla Attention implementation.")
-
-
-def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if model_args.rope_scaling is None:
-        return
-
-    if not hasattr(config, "rope_scaling"):
-        logger.warning("Current model does not support RoPE scaling.")
-        return
-
-    if is_trainable:
-        if model_args.rope_scaling == "dynamic":
-            logger.warning(
-                "Dynamic NTK scaling may not work well with fine-tuning. "
-                "See: https://github.com/huggingface/transformers/pull/24653"
-            )
-
-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if current_max_length and model_args.model_max_length > current_max_length:
-            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
-        else:
-            logger.warning("Input length is smaller than max length. Consider increase input length.")
-            scaling_factor = 1.0
-    else:
-        scaling_factor = 2.0
-
-    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-    logger.info(
-        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
-    )
-
-
-def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if not is_trainable or not model_args.shift_attn:
-        return
-
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
-        setattr(config, "group_size_ratio", 0.25)
-        apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
-    else:
-        logger.warning("Current model does not support shift short attention.")
-
-
-def _configure_quantization(
-    config: "PretrainedConfig",
-    tokenizer: "PreTrainedTokenizer",
-    model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
-) -> None:
-    r"""
-    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
-    """
-    if getattr(config, "quantization_config", None):  # ptq
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
-
-        if model_args.quantization_device_map != "auto":
-            init_kwargs["device_map"] = {"": get_current_device()}
-
-        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
-        quant_method = quantization_config.get("quant_method", "")
-
-        if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
-            quantization_config.pop("disable_exllama", None)  # remove deprecated args
-            quantization_config["use_exllama"] = False  # disable exllama
-
-        if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
-
-        if quant_method == QuantizationMethod.AQLM:
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
-            quantization_config["bits"] = 2
-
-        quant_bits = quantization_config.get("bits", "?")
-        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
-
-    elif model_args.export_quantization_bit is not None:  # auto-gptq
-        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
-        from accelerate.utils import get_max_memory
-
-        if getattr(config, "model_type", None) == "chatglm":
-            raise ValueError("ChatGLM model is not supported.")
-
-        init_kwargs["quantization_config"] = GPTQConfig(
-            bits=model_args.export_quantization_bit,
-            tokenizer=tokenizer,
-            dataset=_get_quantization_dataset(tokenizer, model_args),
-        )
-        init_kwargs["device_map"] = "auto"
-        init_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
-
-    elif model_args.quantization_bit is not None:  # bnb
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-
-        elif model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            init_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type,
-                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
-            )
-
-        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
-            if model_args.quantization_bit != 4:
-                raise ValueError("Only 4-bit quantized model can use auto device map.")
-
-            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
-            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
-            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
-        else:
-            init_kwargs["device_map"] = {"": get_current_device()}
-
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-
-def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
-    embedding_dim = embed_weight.size(1)
-    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
-    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
-    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
-    embed_weight[-num_new_tokens:] = avg_weight + noise_weight
-
-
-def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""
-    Resize token embeddings.
-    """
-    if is_deepspeed_zero3_enabled():
-        import deepspeed  # type: ignore
-
-        params = [model.get_input_embeddings().weight]
-        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
-            params.append(model.get_output_embeddings().weight)
-
-        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
-    else:
-        context_maybe_zero3 = nullcontext()
-
-    with context_maybe_zero3:
-        current_embedding_size = model.get_input_embeddings().weight.size(0)
-
-    if len(tokenizer) > current_embedding_size:
-        if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
-            logger.warning("Current model does not support resizing token embeddings.")
-            return
-
-        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
-        with context_maybe_zero3:
-            new_embedding_size = model.get_input_embeddings().weight.size(0)
-            num_new_tokens = new_embedding_size - current_embedding_size
-
-            _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
-            _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
-
-        logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
-
-
-def _fp32_forward_post_hook(
-    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
-) -> "torch.Tensor":
-    return output.to(torch.float32)
-
-
-def _prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
-) -> None:
-    r"""
-    Includes:
-        (1) cast the layernorm in fp32
-        (2) make output embedding layer require grads
-        (3) add the upcasting of the lm_head in fp32
-    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
-    """
-    if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
-        for name, param in model.named_parameters():
-            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
-                param.data = param.data.to(torch.float32)
-
-    if not model_args.disable_gradient_checkpointing:
-        if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
-        else:
-            # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
-            # According to: https://github.com/huggingface/transformers/issues/28339
-            model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
-
-    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
-        logger.info("Upcasting lm_head outputs in float32.")
-        output_layer = getattr(model, output_layer_name)
-        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            output_layer.register_forward_hook(_fp32_forward_post_hook)
-
-
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -321,10 +42,10 @@ def patch_config(
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    _configure_attn_implementation(config, model_args)
-    _configure_rope(config, model_args, is_trainable)
-    _configure_longlora(config, model_args, is_trainable)
-    _configure_quantization(config, tokenizer, model_args, init_kwargs)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args, is_trainable)
+    configure_longlora(config, model_args, is_trainable)
+    configure_quantization(config, tokenizer, model_args, init_kwargs)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -377,22 +98,14 @@ def patch_model(
         setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 
     if model_args.resize_vocab:
-        _resize_embedding_layer(model, tokenizer)
+        resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
-        _prepare_model_for_training(model, model_args)
-
-    if getattr(model.config, "model_type", None) == "mixtral":
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-        add_z3_leaf_module(model, MixtralSparseMoeBlock)
-
-    if getattr(model.config, "model_type", None) == "qwen2moe":
-        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
-        add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
-
-    _print_attn_implementation(model.config)
+        prepare_model_for_training(model, model_args)
+        add_z3_leaf_module(model)
+
+    if not model_args.use_unsloth:
+        print_attn_implementation(model.config)
 
     try:
         model.add_model_tags(["llama-factory"])
src/llmtuner/model/utils/attention.py (new file, 55 lines)

from typing import TYPE_CHECKING

from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if model_args.flash_attn == "auto":
        return

    elif model_args.flash_attn == "off":
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == "sdpa":
        if not is_sdpa_available():
            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == "fa2":
        if not is_flash_attn2_available():
            logger.warning("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"
    else:
        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
    else:
        setattr(config, "_attn_implementation", requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
        logger.info("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info("Using torch SDPA for faster training and inference.")
    else:
        logger.info("Using vanilla Attention implementation.")
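The helper above only mutates the config object; model loading happens elsewhere. A minimal sketch of exercising it on its own (illustrative, not part of this commit; flash_attn is the ModelArguments field read above, the constructor details are assumptions):

from transformers import LlamaConfig

from llmtuner.hparams import ModelArguments
from llmtuner.model.utils.attention import configure_attn_implementation, print_attn_implementation

config = LlamaConfig()
model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
model_args.flash_attn = "sdpa"  # one of the choices handled above; set by hand for the sketch

configure_attn_implementation(config, model_args)  # sets config._attn_implementation when torch>=2.1.1
print_attn_implementation(config)  # logs which backend ended up selected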
94
src/llmtuner/model/utils/checkpointing.py
Normal file
94
src/llmtuner/model/utils/checkpointing.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import inspect
|
||||||
|
from functools import partial
|
||||||
|
from types import MethodType
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...extras.constants import LAYERNORM_NAMES
|
||||||
|
from ...extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from ...hparams import ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _gradient_checkpointing_enable(
|
||||||
|
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Activates gradient checkpointing for the current model.
|
||||||
|
|
||||||
|
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
|
||||||
|
"""
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
if not self.supports_gradient_checkpointing:
|
||||||
|
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
|
||||||
|
|
||||||
|
if gradient_checkpointing_kwargs is None:
|
||||||
|
gradient_checkpointing_kwargs = {"use_reentrant": True}
|
||||||
|
|
||||||
|
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||||
|
|
||||||
|
def custom_gradient_checkpointing_func(func, *args, **kwargs):
|
||||||
|
module: "torch.nn.Module" = func.__self__
|
||||||
|
|
||||||
|
if any(param.requires_grad for param in module.parameters()):
|
||||||
|
for arg in args:
|
||||||
|
if torch.is_tensor(arg) and torch.is_floating_point(arg):
|
||||||
|
arg.requires_grad_(True)
|
||||||
|
|
||||||
|
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||||
|
|
||||||
|
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||||
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||||
|
self.enable_input_require_grads()
|
||||||
|
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||||
|
else: # have already enabled input require gradients
|
||||||
|
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||||
|
|
||||||
|
|
||||||
|
def _fp32_forward_post_hook(
|
||||||
|
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
return output.to(torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_model_for_training(
|
||||||
|
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
Includes:
|
||||||
|
(1) cast the layernorm in fp32
|
||||||
|
(2) make output embedding layer require grads
|
||||||
|
(3) add the upcasting of the lm_head in fp32
|
||||||
|
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
|
||||||
|
"""
|
||||||
|
if model_args.upcast_layernorm:
|
||||||
|
logger.info("Upcasting layernorm weights in float32.")
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
if not model_args.disable_gradient_checkpointing:
|
||||||
|
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||||
|
logger.warning("Current model does not support gradient checkpointing.")
|
||||||
|
else:
|
||||||
|
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||||
|
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||||
|
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
|
||||||
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||||
|
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||||
|
logger.info("Gradient checkpointing enabled.")
|
||||||
|
|
||||||
|
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||||
|
logger.info("Upcasting lm_head outputs in float32.")
|
||||||
|
output_layer = getattr(model, output_layer_name)
|
||||||
|
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||||
|
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
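prepare_model_for_training combines the layernorm upcast, the patched gradient-checkpointing method, and the fp32 lm_head hook. A rough usage sketch, assuming the ModelArguments defaults leave gradient checkpointing enabled (illustrative, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

from llmtuner.hparams import ModelArguments
from llmtuner.model.utils.checkpointing import prepare_model_for_training

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, torch_dtype=torch.bfloat16)

prepare_model_for_training(model, model_args)
# with gradient checkpointing enabled, the KV cache is switched off:
assert model.config.use_cache is False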
src/llmtuner/model/utils/embedding.py (new file, 56 lines)

import math
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = get_logger(__name__)


def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Resize token embeddings.
    """
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [model.get_input_embeddings().weight]
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            params.append(model.get_output_embeddings().weight)

        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    with context_maybe_zero3:
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) > current_embedding_size:
        if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
            logger.warning("Current model does not support resizing token embeddings.")
            return

        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
        with context_maybe_zero3:
            new_embedding_size = model.get_input_embeddings().weight.size(0)
            num_new_tokens = new_embedding_size - current_embedding_size

            _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
            _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

        logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
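resize_embedding_layer only acts when the tokenizer has grown beyond the current embedding matrix. An illustrative sketch (not part of this commit; the token string is made up):

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmtuner.model.utils.embedding import resize_embedding_layer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

tokenizer.add_special_tokens({"pad_token": "<|pad|>"})  # len(tokenizer) now exceeds the embedding size
resize_embedding_layer(model, tokenizer)  # pads to a multiple of 64 and noisy-mean-initializes the new rows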
@@ -1,5 +1,5 @@
 import math
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -7,19 +7,28 @@ from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
     LlamaFlashAttention2,
+    LlamaSdpaAttention,
     apply_rotary_pos_emb,
     repeat_kv,
 )
 from transformers.utils import logging
 from transformers.utils.versions import require_version
 
+from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
 
 logger = logging.get_logger(__name__)
 
 
 # Modified from:
-# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
-def llama_torch_attn_forward(
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_attention_forward(
     self: "LlamaAttention",
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -39,10 +48,11 @@ def llama_torch_attn_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    past_key_value = getattr(self, "past_key_value", past_key_value)
     cos, sin = self.rotary_emb(value_states, position_ids)
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+    past_key_value = getattr(self, "past_key_value", past_key_value)
+
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -69,8 +79,9 @@ def llama_torch_attn_forward(
 
     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    if attention_mask is not None:
-        attn_weights = attn_weights + attention_mask
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
 
     # upcast attention to fp32
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -97,8 +108,8 @@ def llama_torch_attn_forward(
 
 
 # Modified from:
-# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
-def llama_flash_attn_forward(
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_flash_attention_2_forward(
     self: "LlamaFlashAttention2",
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -117,7 +128,6 @@ def llama_flash_attn_forward(
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
 
-    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -134,9 +144,10 @@ def llama_flash_attn_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    query_states = query_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-    key_states = key_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
-    value_states = value_states.transpose(1, 2)  # (bsz, seq_len, n_heads, head_dim)
+    # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.transpose(1, 2)
 
     dropout_rate = self.attention_dropout if self.training else 0.0
 
@@ -192,7 +203,115 @@ def llama_flash_attn_forward(
     return attn_output, attn_weights, past_key_value
 
 
-def apply_llama_patch() -> None:
-    require_version("transformers==4.39.3", "To fix: pip install transformers==4.39.3")
-    LlamaAttention.forward = llama_torch_attn_forward
-    LlamaFlashAttention2.forward = llama_flash_attn_forward
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+def llama_sdpa_attention_forward(
+    self: "LlamaSdpaAttention",
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional["Cache"] = None,
+    output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    if output_attentions:
+        logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
+        return llama_attention_forward(
+            self,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
+        groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        num_groups = q_len // groupsz
+
+        def shift(state: torch.Tensor) -> torch.Tensor:
+            state = state.transpose(1, 2)  # output: (bsz, seq_len, n_heads, head_dim)
+            state = torch.cat(
+                (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
+                dim=2,
+            )
+            return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+        query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
+
+    causal_mask = attention_mask
+    if attention_mask is not None:
+        causal_mask = causal_mask[:, :, :, :groupsz]
+
+    query_states = query_states.contiguous()
+    key_states = key_states.contiguous()
+    value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=causal_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        is_causal=causal_mask is None and q_len > 1,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
+        attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
+        attn_output = torch.cat(
+            (
+                attn_output[:, :, : self.num_heads // 2],
+                attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
+            )
+        )
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+def _apply_llama_patch() -> None:
+    require_version("transformers==4.40.0", "To fix: pip install transformers==4.40.0")
+    LlamaAttention.forward = llama_attention_forward
+    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
+    LlamaSdpaAttention.forward = llama_sdpa_attention_forward
+
+
+def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.shift_attn:
+        return
+
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+        setattr(config, "group_size_ratio", 0.25)
+        _apply_llama_patch()
+        logger.info("Using shift short attention with group_size_ratio=1/4.")
+    else:
+        logger.warning("Current model does not support shift short attention.")
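The shift() helper added above implements LongLoRA-style shift short attention: half of the heads are rolled by half a group before attention is computed group by group. A toy-shape walkthrough of that bookkeeping (illustrative, not part of this commit):

import torch

bsz, q_len, num_heads, head_dim = 1, 8, 4, 2
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)  # 2
num_groups = q_len // groupsz  # 4

state = torch.arange(bsz * num_heads * q_len * head_dim, dtype=torch.float32)
state = state.view(bsz, num_heads, q_len, head_dim)

shifted = state.transpose(1, 2)  # (bsz, q_len, n_heads, head_dim)
shifted = torch.cat(
    (shifted[:, :, : num_heads // 2], shifted[:, :, num_heads // 2 :].roll(-groupsz // 2, dims=1)), dim=2
)
shifted = shifted.reshape(bsz * num_groups, groupsz, num_heads, head_dim).transpose(1, 2)
print(shifted.shape)  # torch.Size([4, 4, 2, 2]) -> (bsz * num_groups, n_heads, groupsz, head_dim)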
@@ -1,51 +1,23 @@
-import inspect
-from enum import Enum, unique
-from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List
 
 import torch
 from transformers import PreTrainedModel
-from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import cached_file
-from transformers.utils.versions import require_version
 
-from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.logging import get_logger
+from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ...extras.logging import get_logger
+from .quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
 
-    from ..hparams import ModelArguments
+    from ...hparams import ModelArguments
 
 
 logger = get_logger(__name__)
 
 
-@unique
-class QuantizationMethod(str, Enum):
-    r"""
-    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
-    """
-
-    BITS_AND_BYTES = "bitsandbytes"
-    GPTQ = "gptq"
-    AWQ = "awq"
-    AQLM = "aqlm"
-    QUANTO = "quanto"
-
-
-def add_z3_leaf_module(model: "PreTrainedModel", module: "torch.nn.Module") -> None:
-    r"""
-    Sets module as a leaf module to skip partitioning in deepspeed zero3.
-    """
-    if is_deepspeed_zero3_enabled():
-        require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
-        from deepspeed.utils import set_z3_leaf_modules  # type: ignore
-
-        set_z3_leaf_modules(model, [module])
-
-
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora or galore.
@@ -102,42 +74,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     return module_names
 
 
-def gradient_checkpointing_enable(
-    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
-) -> None:
-    r"""
-    Activates gradient checkpointing for the current model.
-
-    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
-    """
-    from torch.utils.checkpoint import checkpoint
-
-    if not self.supports_gradient_checkpointing:
-        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
-
-    if gradient_checkpointing_kwargs is None:
-        gradient_checkpointing_kwargs = {"use_reentrant": True}
-
-    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
-
-    def custom_gradient_checkpointing_func(func, *args, **kwargs):
-        module: "torch.nn.Module" = func.__self__
-
-        if any(param.requires_grad for param in module.parameters()):
-            for arg in args:
-                if torch.is_tensor(arg) and torch.is_floating_point(arg):
-                    arg.requires_grad_(True)
-
-        return gradient_checkpointing_func(func, *args, **kwargs)
-
-    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
-        self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
-    else:  # have already enabled input require gradients
-        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
src/llmtuner/model/utils/moe.py (new file, 39 lines)

from typing import TYPE_CHECKING

from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version


if TYPE_CHECKING:
    from transformers import PreTrainedModel


def add_z3_leaf_module(model: "PreTrainedModel") -> None:
    r"""
    Sets module as a leaf module to skip partitioning in deepspeed zero3.
    """
    if not is_deepspeed_zero3_enabled():
        return

    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    if getattr(model.config, "model_type", None) == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "qwen2moe":
        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "jamba":
        from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

        set_z3_leaf_modules(model, [JambaSparseMoeBlock])

    if getattr(model.config, "model_type", None) == "dbrx":
        from transformers.models.dbrx.modeling_dbrx import DbrxFFN

        set_z3_leaf_modules(model, [DbrxFFN])
src/llmtuner/model/utils/quantization.py (new file, 146 lines)

import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version

from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer

    from ...hparams import ModelArguments


logger = get_logger(__name__)


@unique
class QuantizationMethod(str, Enum):
    r"""
    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
    """

    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
    r"""
    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
    """
    if os.path.isfile(model_args.export_quantization_dataset):
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
        data_files = model_args.export_quantization_dataset
    else:
        data_path = model_args.export_quantization_dataset
        data_files = None

    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
    maxlen = model_args.export_quantization_maxlen

    samples = []
    for _ in range(model_args.export_quantization_nsamples):
        while True:
            sample_idx = random.randint(0, len(dataset) - 1)
            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            if sample["input_ids"].size(1) >= maxlen:
                break  # TODO: fix large maxlen

        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))

    return samples


def configure_quantization(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
) -> None:
    r"""
    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
    """
    if getattr(config, "quantization_config", None):  # ptq
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")

        if model_args.quantization_device_map != "auto":
            init_kwargs["device_map"] = {"": get_current_device()}

        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")

        if quant_method == QuantizationMethod.GPTQ:
            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
            quantization_config.pop("disable_exllama", None)  # remove deprecated args
            quantization_config["use_exllama"] = False  # disable exllama

        if quant_method == QuantizationMethod.AWQ:
            require_version("autoawq", "To fix: pip install autoawq")

        if quant_method == QuantizationMethod.AQLM:
            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
            quantization_config["bits"] = 2

        quant_bits = quantization_config.get("bits", "?")
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

    elif model_args.export_quantization_bit is not None:  # auto-gptq
        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
        from accelerate.utils import get_max_memory

        if getattr(config, "model_type", None) == "chatglm":
            raise ValueError("ChatGLM model is not supported.")

        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
            tokenizer=tokenizer,
            dataset=_get_quantization_dataset(tokenizer, model_args),
        )
        init_kwargs["device_map"] = "auto"
        init_kwargs["max_memory"] = get_max_memory()
        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))

    elif model_args.quantization_bit is not None:  # bnb
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            init_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type,
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
            )

        if is_deepspeed_zero3_enabled() or model_args.quantization_device_map == "auto":
            if model_args.quantization_bit != 4:
                raise ValueError("Only 4-bit quantized model can use auto device map.")

            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
        else:
            init_kwargs["device_map"] = {"": get_current_device()}

        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
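configure_quantization never loads the model itself; it only fills init_kwargs for the later from_pretrained call. A rough sketch of the bitsandbytes 4-bit path, assuming default ModelArguments values for the remaining fields (illustrative, not part of this commit):

import torch
from transformers import AutoConfig, AutoTokenizer

from llmtuner.hparams import ModelArguments
from llmtuner.model.utils.quantization import configure_quantization

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
model_args.quantization_bit = 4  # takes the bnb branch above
model_args.compute_dtype = torch.bfloat16  # normally inferred in patch_config before this call

config = AutoConfig.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

init_kwargs = {}
configure_quantization(config, tokenizer, model_args, init_kwargs)
print(init_kwargs["quantization_config"])  # BitsAndBytesConfig(load_in_4bit=True, ...)
print(init_kwargs["device_map"])  # {"": <current device>} on the non-ZeRO-3 path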
src/llmtuner/model/utils/rope.py (new file, 43 lines)

import math
from typing import TYPE_CHECKING

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if model_args.rope_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning("Current model does not support RoPE scaling.")
        return

    if is_trainable:
        if model_args.rope_scaling == "dynamic":
            logger.warning(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        current_max_length = getattr(config, "max_position_embeddings", None)
        if current_max_length and model_args.model_max_length > current_max_length:
            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
        else:
            logger.warning("Input length is smaller than max length. Consider increase input length.")
            scaling_factor = 1.0
    else:
        scaling_factor = 2.0

    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
    logger.info(
        "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
    )
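During training, configure_rope picks the scaling factor as ceil(model_max_length / max_position_embeddings). An illustrative sketch (not part of this commit; rope_scaling and model_max_length are the fields read above, set here by hand):

from transformers import AutoConfig

from llmtuner.hparams import ModelArguments
from llmtuner.model.utils.rope import configure_rope

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # max_position_embeddings == 4096
model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
model_args.rope_scaling = "linear"
model_args.model_max_length = 16384

configure_rope(config, model_args, is_trainable=True)
print(config.rope_scaling)  # {"type": "linear", "factor": 4.0}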