rename package

Former-commit-id: 308edbc426
This commit is contained in:
hiyouga
2024-05-16 18:39:08 +08:00
parent 93a289107b
commit cae823ddf0
109 changed files with 31 additions and 31 deletions

View File

@@ -0,0 +1,12 @@
from .loader import load_config, load_model, load_tokenizer
from .utils.misc import find_all_linear_modules
from .utils.valuehead import load_valuehead_params
__all__ = [
"load_config",
"load_model",
"load_tokenizer",
"load_valuehead_params",
"find_all_linear_modules",
]

View File

@@ -0,0 +1,225 @@
import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from .utils.misc import find_all_linear_modules, find_expanded_modules
from .utils.quantization import QuantizationMethod
from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def init_adapter(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if (not is_trainable) and model_args.adapter_name_or_path is None:
logger.info("Adapter is not found at evaluation, load the base model.")
return model
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
raise ValueError("You can only use lora for quantized models.")
if deepspeed_config() is not None or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("DeepSpeed/FSDP/PureBF16/BAdam detected, remaining trainable params in half precision.")
cast_trainable_params_to_fp32 = False
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if cast_trainable_params_to_fp32:
model = model.float()
if model_args.visual_inputs and hasattr(model, "vision_tower"): # freeze vision model
model.vision_tower.requires_grad_(False)
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
)
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(
module_name, ", ".join(non_hidden_modules)
)
)
trainable_layers.append(module_name)
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(
model, adapter, offload_folder=model_args.offload_folder
)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(
model,
adapter_to_resume,
is_trainable=is_trainable,
offload_folder=model_args.offload_folder,
)
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model)
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
return model

View File

@@ -0,0 +1,183 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, try_download_model_from_ms
from .adapter import init_adapter
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .utils.misc import register_autoclass
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .utils.unsloth import load_unsloth_pretrained_model
from .utils.valuehead import load_valuehead_params
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False)
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:d}".format(all_param)
logger.info(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
return model

View File

@@ -0,0 +1,140 @@
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from ..extras.misc import infer_optim_dtype
from .utils.attention import configure_attn_implementation, print_attn_implementation
from .utils.checkpointing import prepare_model_for_training
from .utils.embedding import resize_embedding_layer
from .utils.longlora import configure_longlora
from .utils.moe import add_z3_leaf_module, configure_moe
from .utils.quantization import configure_quantization
from .utils.rope import configure_rope
from .utils.valuehead import prepare_valuehead_model
from .utils.visual import autocast_projector_dtype, configure_visual_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
logger = get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
configure_attn_implementation(config, model_args)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
init_kwargs["torch_dtype"] = model_args.compute_dtype
if not is_deepspeed_zero3_enabled() and not is_fsdp_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
if init_kwargs["device_map"] == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool,
add_valuehead: bool,
) -> None:
gen_config = model.generation_config # check and fix generation config
if not gen_config.do_sample and (
(gen_config.temperature is not None and gen_config.temperature != 1.0)
or (gen_config.top_p is not None and gen_config.top_p != 1.0)
or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
):
gen_config.do_sample = True
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
if add_valuehead:
prepare_valuehead_model(model)
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
if model_args.visual_inputs:
autocast_projector_dtype(model, model_args)
if is_trainable:
prepare_model_for_training(model, model_args)
add_z3_leaf_module(model)
if not model_args.use_unsloth:
print_attn_implementation(model.config)
try:
model.add_model_tags(["llama-factory"])
except Exception:
logger.warning("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):
self.pretrained_model.tie_weights()
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))

View File

View File

@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla attention implementation.")

View File

@@ -0,0 +1,94 @@
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
logger.info("Upcasting lm_head outputs in float32.")
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(_fp32_forward_post_hook)

View File

@@ -0,0 +1,58 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))

View File

@@ -0,0 +1,323 @@
import math
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
else:
groupsz = q_len
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: torch.Tensor) -> torch.Tensor:
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
)
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
logger = get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")

View File

@@ -0,0 +1,78 @@
from typing import TYPE_CHECKING, List
import torch
from ...extras.logging import get_logger
from .quantization import QuantizationMethod
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
r"""
Finds all available modules to apply lora or galore.
"""
quantization_method = getattr(model, "quantization_method", None)
if quantization_method is None:
linear_cls = torch.nn.Linear
elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
else:
raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
output_layer_names = ["lm_head"]
if model.config.model_type == "chatglm":
output_layer_names.append("output_layer")
elif model.config.model_type == "internlm2":
output_layer_names.append("output")
module_names = set()
for name, module in model.named_modules():
if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model

View File

@@ -0,0 +1,53 @@
from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled():
return
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
set_z3_leaf_modules(model, [DbrxFFN])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)

View File

@@ -0,0 +1,147 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = model_args.export_quantization_maxlen
samples = []
for _ in range(model_args.export_quantization_nsamples):
while True:
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
if sample["input_ids"].size(1) >= maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples
def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # ptq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use auto device map.")
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

View File

@@ -0,0 +1,47 @@
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if is_trainable:
if model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)

View File

@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"use_gradient_checkpointing": "unsloth",
}
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
return model
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
if not is_trainable:
FastLanguageModel.for_inference(model)
return model

View File

@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
logger.info("Ignore these messages if you are not resuming the training of a value head model.")
return None
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
if getattr(model.config, "model_type", None) == "llava":
setattr(model, "lm_head", model.language_model.get_output_embeddings())
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "internlm2":
setattr(model, "lm_head", model.output)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

View File

@@ -0,0 +1,84 @@
from typing import TYPE_CHECKING, Tuple
import torch
import transformers.models
from transformers.activations import ACT2FN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
def __init__(self, config: "LlavaConfig") -> None:
super().__init__()
self.config = config
if config is None:
return
self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
hidden_states = self.linear_1(image_features)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_3(hidden_states)
hidden_states = self.linear_4(hidden_states)
if hidden_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.linear_1.weight.dtype
logger.warning_once("The hidden states seems to be silently casted in float32.")
hidden_states = hidden_states.to(target_dtype)
return hidden_states
class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
super().__init__(config=None)
self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL