Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 04:32:50 +08:00)
add upcast_lmhead option
Former-commit-id: 8cbe4e960983ace47f8b956cdc31411347592129
parent ffde3d94bf
commit edf725208d
@@ -70,6 +70,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
     )
+    upcast_lmhead_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
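For context, a minimal, self-contained sketch of how the new flag would surface through the dataclass-based CLI parsing. The field name and help text come from the hunk above; the HfArgumentParser usage and the omission of the surrounding ModelArguments fields are assumptions, not part of this commit.

# Minimal sketch (assumption: HfArgumentParser-style parsing as used elsewhere
# in the repo; other ModelArguments fields omitted for brevity).
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    upcast_lmhead_output: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}
    )


if __name__ == "__main__":
    # e.g. python example.py --upcast_lmhead_output True
    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
    print(model_args.upcast_lmhead_output)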
@@ -209,7 +209,7 @@ def _prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
-    if hasattr(model, output_layer_name):
+    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
         def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
             return output.to(torch.float32)
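A hedged sketch of what the changed condition gates: a forward post-hook that casts the lm_head logits to fp32, attached only when the model exposes the output layer and the user opts in. The helper name, the getattr lookup, and the register_forward_hook call are assumptions about code outside this hunk.

# Sketch of the opt-in upcast hook; maybe_upcast_lmhead_output and the
# registration call are illustrative, not the repo's exact code.
from typing import Tuple

import torch


def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
    # Cast the logits leaving lm_head to fp32 so the loss is computed in full
    # precision even when the rest of the model runs in fp16/bf16.
    return output.to(torch.float32)


def maybe_upcast_lmhead_output(model: torch.nn.Module, upcast_lmhead_output: bool, output_layer_name: str = "lm_head") -> None:
    # Attach the hook only when the model has the named output layer and the
    # new flag is enabled, mirroring the condition in the hunk above.
    if hasattr(model, output_layer_name) and upcast_lmhead_output:
        output_layer = getattr(model, output_layer_name)
        output_layer.register_forward_hook(fp32_forward_post_hook)

A forward post-hook leaves the model weights untouched; only the output tensor of lm_head is upcast, which is what the downstream cross-entropy loss consumes.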