mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
update patcher
Former-commit-id: e44b82ee245a7ee99057c7b58b1edef5c222dc1f
This commit is contained in:
parent
11a6c8a9a0
commit
5d440f978e
@ -214,7 +214,7 @@ huggingface-cli login
|
||||
|
||||
| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
||||
| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
|
||||
| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
|
@ -214,7 +214,7 @@ huggingface-cli login
|
||||
|
||||
| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
|
||||
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
|
||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
|
||||
| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 900GB |
|
||||
| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
|
||||
| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
|
||||
| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
|
||||
|
@ -22,6 +22,10 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
resize_vocab: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to resize the tokenizer vocab and the embedding layers."}
|
||||
)
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
||||
|
@ -8,9 +8,7 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
||||
from llmtuner.model.adapter import init_adapter
|
||||
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
|
||||
from llmtuner.model.utils import (
|
||||
load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
|
||||
)
|
||||
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, register_autoclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@ -94,10 +92,8 @@ def load_model_and_tokenizer(
|
||||
)
|
||||
|
||||
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
|
||||
patch_model(model)
|
||||
patch_model(model, tokenizer, model_args)
|
||||
register_autoclass(config, model, tokenizer)
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
resize_embedding_layer(model, tokenizer)
|
||||
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
|
@ -25,105 +25,34 @@ logger = get_logger(__name__)
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(avg_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]):
|
||||
if model_args.flash_attn and is_flash_attn2_available():
|
||||
config_kwargs["use_flash_attention_2"] = True
|
||||
config_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
|
||||
def _configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
config_kwargs: Dict[str, Any]
|
||||
):
|
||||
def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
Resize token embeddings.
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # gptq or awq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
logger.warning("Current model does not support resizing token embeddings.")
|
||||
return
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
config_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=get_quantization_dataset(tokenizer, model_args)
|
||||
)
|
||||
config_kwargs["device_map"] = "auto"
|
||||
config_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
r"""
|
||||
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
@ -153,7 +82,105 @@ def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mode
|
||||
return samples
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK scaling may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
|
||||
def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
|
||||
if model_args.flash_attn and is_flash_attn2_available():
|
||||
config_kwargs["use_flash_attention_2"] = True
|
||||
config_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
|
||||
def _configure_quantization(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
config_kwargs: Dict[str, Any]
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # gptq or awq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
quantization_config = getattr(config, "quantization_config", None)
|
||||
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
|
||||
config_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=_get_quantization_dataset(tokenizer, model_args)
|
||||
)
|
||||
config_kwargs["device_map"] = "auto"
|
||||
config_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
|
||||
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
@ -164,7 +191,7 @@ def patch_config(
|
||||
model_args: "ModelArguments",
|
||||
config_kwargs: Dict[str, Any],
|
||||
is_trainable: bool
|
||||
):
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
@ -179,7 +206,7 @@ def patch_config(
|
||||
_configure_quantization(config, tokenizer, model_args, config_kwargs)
|
||||
|
||||
|
||||
def patch_model(model: "PreTrainedModel"):
|
||||
def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
@ -187,8 +214,13 @@ def patch_model(model: "PreTrainedModel"):
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
if model_args.resize_vocab:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with vocab resizing.")
|
||||
|
||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
|
||||
_resize_embedding_layer(model, tokenizer)
|
||||
|
||||
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
self.pretrained_model.tie_weights()
|
||||
|
@ -123,14 +123,6 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
|
||||
return None
|
||||
|
||||
|
||||
def noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
|
||||
embedding_dim = embed_weight.size(1)
|
||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
noise_weight = torch.empty_like(avg_weight[-num_new_tokens:])
|
||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
@ -176,25 +168,6 @@ def prepare_model_for_training(
|
||||
return model
|
||||
|
||||
|
||||
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
|
||||
r"""
|
||||
Resize token embeddings.
|
||||
"""
|
||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
if len(tokenizer) > current_embedding_size:
|
||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||
logger.warning("Current model does not support resizing token embeddings.")
|
||||
return
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||
num_new_tokens = new_embedding_size - current_embedding_size
|
||||
noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
|
||||
noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
|
||||
|
||||
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
|
||||
|
||||
|
||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
|
||||
if "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
|
Loading…
x
Reference in New Issue
Block a user