From edf725208d567eb57a15d41856024d9cf583d720 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 19 Jan 2024 23:54:25 +0800
Subject: [PATCH] add upcast_lmhead option

Former-commit-id: 8cbe4e960983ace47f8b956cdc31411347592129
---
 src/llmtuner/hparams/model_args.py | 4 ++++
 src/llmtuner/model/patcher.py      | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 356de716..15703e3b 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -70,6 +70,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
     )
+    upcast_lmhead_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 5ce8d604..9a8680fb 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -209,7 +209,7 @@ def _prepare_model_for_training(
         model.config.use_cache = False  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name):
+    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)
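
Note: the sketch below is illustrative only and is not part of the patch. It shows what the now-gated branch effectively does once upcast_lmhead_output is enabled: a PyTorch forward post-hook is registered on the output layer so logits come back in fp32 even when the model itself runs in reduced precision. The TinyLM module, the standalone upcast_lmhead_output flag, and output_layer_name = "lm_head" are stand-ins for the real model, ModelArguments field, and patcher logic.

# Illustrative sketch (not part of the patch): upcasting lm_head outputs to fp32
# via a forward post-hook, assuming a PyTorch build with bf16 support on the host.
import torch
from typing import Tuple

class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lm_head = torch.nn.Linear(16, 32)

    def forward(self, x):
        return self.lm_head(x)

def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
    # Returning a value from a forward hook replaces the module's output.
    return output.to(torch.float32)

model = TinyLM().to(torch.bfloat16)
upcast_lmhead_output = True   # stand-in for model_args.upcast_lmhead_output
output_layer_name = "lm_head" # stand-in for the patcher's output_layer_name

if hasattr(model, output_layer_name) and upcast_lmhead_output:
    getattr(model, output_layer_name).register_forward_hook(fp32_forward_post_hook)

logits = model(torch.randn(2, 16, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32 when the hook is registered, torch.bfloat16 otherwise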