diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index f0363274..ef881c44 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -2,7 +2,7 @@ IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
 
 METHODS = ["full", "freeze", "lora"]
 
diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index ba4e603c..830c7ce3 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -19,21 +19,6 @@ except ImportError:
 logger = logging.get_logger(__name__)
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight * hidden_states).to(input_dtype)
-
-
 class LlamaShiftShortAttention(LlamaAttention):
 
     def forward(
@@ -162,6 +147,14 @@ class LlamaFlashAttention2(LlamaAttention):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # cast to half precision
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seems to be silently casted in float32.")
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
         if getattr(self, "num_key_value_groups"):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index e5bbc04c..a26f8aa2 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -67,6 +67,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
+        default="auto",
+        metadata={"help": "Data type of the layer norm weights."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 36525d33..55237d88 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -128,10 +128,6 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
-    if getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
-
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
@@ -205,7 +201,8 @@
         tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
+    if is_trainable:
+        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
 
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 46d89cbf..d2d0113a 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -226,6 +226,17 @@ def get_train_args(
     else:
         model_args.compute_dtype = _infer_dtype()
 
+    if model_args.layernorm_dtype == "bf16":
+        if not is_bf16_available:
+            raise ValueError("Current device does not support bf16 type.")
+        model_args.layernorm_dtype = torch.bfloat16
+    elif model_args.layernorm_dtype == "fp16":
+        model_args.layernorm_dtype = torch.float16
+    elif model_args.layernorm_dtype == "fp32":
+        model_args.layernorm_dtype = torch.float32
+    else:
+        model_args.layernorm_dtype = model_args.compute_dtype
+
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py
index 74ff075f..a07fa31c 100644
--- a/src/llmtuner/tuner/core/utils.py
+++ b/src/llmtuner/tuner/core/utils.py
@@ -31,6 +31,7 @@ def find_all_linear_modules(
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
+    layernorm_dtype: torch.dtype,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
@@ -45,7 +46,7 @@
     """
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(layernorm_dtype)
 
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
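
For reference, below is a minimal, self-contained sketch (not part of the patch) of what the new layernorm_dtype option does end to end: the string value from ModelArguments is resolved to a torch dtype, and every 1-D parameter whose name matches LAYERNORM_NAMES is cast to that dtype. The helper resolve_layernorm_dtype and the TinyBlock module are hypothetical stand-ins for the logic in parser.py and prepare_model_for_training.

# Hypothetical illustration only; TinyBlock and resolve_layernorm_dtype are made-up
# names, not part of llmtuner.
import torch
import torch.nn as nn

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]


def resolve_layernorm_dtype(choice: str, compute_dtype: torch.dtype) -> torch.dtype:
    # mirrors the branch added to parser.py: "auto" falls back to compute_dtype
    if choice == "bf16":
        if not torch.cuda.is_bf16_supported():
            raise ValueError("Current device does not support bf16 type.")
        return torch.bfloat16
    if choice == "fp16":
        return torch.float16
    if choice == "fp32":
        return torch.float32
    return compute_dtype


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(8)   # name matches LAYERNORM_NAMES
        self.proj = nn.Linear(8, 8)   # left untouched by the cast


model = TinyBlock().to(torch.float16)
target = resolve_layernorm_dtype("fp32", compute_dtype=torch.float16)

# same loop as prepare_model_for_training: cast only 1-D layer-norm parameters
for name, param in model.named_parameters():
    if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
        param.data = param.data.to(target)

print(model.ln_1.weight.dtype)  # torch.float32
print(model.proj.weight.dtype)  # torch.float16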