diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 356de716..15703e3b 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -70,6 +70,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
     )
+    upcast_lmhead_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 5ce8d604..9a8680fb 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -209,7 +209,7 @@ def _prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name):
+    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)
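
The patch makes the existing lm_head upcasting opt-in: the forward hook that casts the output layer's result to float32 is now only registered when the new `upcast_lmhead_output` flag is set. Below is a minimal, self-contained sketch of that hook pattern, using a toy `torch.nn.Linear` as a stand-in for a model's lm_head (the layer name and the standalone `upcast_lmhead_output` variable here are illustrative, not the project's actual wiring):

```python
from typing import Tuple

import torch


def fp32_forward_post_hook(
    module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
    # Returning a tensor from a forward hook replaces the module's output,
    # so the half-precision logits are upcast to float32 before the loss.
    return output.to(torch.float32)


# Toy stand-in for a model's output layer kept in reduced precision.
lm_head = torch.nn.Linear(16, 32).to(torch.bfloat16)
upcast_lmhead_output = True  # mirrors the new ModelArguments flag

if upcast_lmhead_output:
    lm_head.register_forward_hook(fp32_forward_post_hook)

logits = lm_head(torch.randn(2, 16, dtype=torch.bfloat16))
print(logits.dtype)  # torch.float32 when the hook is registered
```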