add upcast_lmhead option

Former-commit-id: 8cbe4e960983ace47f8b956cdc31411347592129
hiyouga 2024-01-19 23:54:25 +08:00
parent ffde3d94bf
commit edf725208d
2 changed files with 5 additions and 1 deletion


@@ -70,6 +70,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
     )
+    upcast_lmhead_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}

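For reference, a minimal sketch (not part of this commit) of how the new dataclass field surfaces as a command-line flag. It assumes the arguments are parsed with transformers.HfArgumentParser, as is common in Hugging Face training scripts; the ModelArguments class below is a trimmed, hypothetical stand-in for the full class in the diff.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:  # trimmed stand-in for the full class in the diff
    upcast_lmhead_output: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
    )


parser = HfArgumentParser(ModelArguments)

# Passing "--upcast_lmhead_output true" on a real command line ends up here.
(model_args,) = parser.parse_args_into_dataclasses(["--upcast_lmhead_output", "true"])
print(model_args.upcast_lmhead_output)  # True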

@@ -209,7 +209,7 @@ def _prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name):
+    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)
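To show what the now-guarded hook does at runtime, here is a self-contained sketch. The plain nn.Linear standing in for the model's output layer and the bfloat16 dtype (chosen so the example runs on CPU) are assumptions for illustration; the hook body matches the diff. Returning a tensor from a PyTorch forward hook replaces the module's output, so downstream loss computation receives fp32 logits even though the weights stay in reduced precision.

from typing import Tuple

import torch

# Stand-in for a language-model head running in reduced precision (assumption).
lm_head = torch.nn.Linear(8, 32).to(torch.bfloat16)

def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
    # A non-None return value from a forward hook replaces the module's output.
    return output.to(torch.float32)

lm_head.register_forward_hook(fp32_forward_post_hook)

x = torch.randn(2, 8, dtype=torch.bfloat16)
print(lm_head(x).dtype)  # torch.float32 -- the logits were upcast by the hook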