fix layer norm dtype

hiyouga
2023-09-28 00:25:55 +08:00
parent b0b0138e1d
commit 84b7486885
6 changed files with 28 additions and 22 deletions


@@ -31,6 +31,7 @@ def find_all_linear_modules(
 def prepare_model_for_training(
     model: "PreTrainedModel",
+    layernorm_dtype: torch.dtype,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
@@ -45,7 +46,7 @@ def prepare_model_for_training(
     """
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(layernorm_dtype)
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
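
For context: after this change, callers of prepare_model_for_training choose the dtype that 1-D LayerNorm parameters are cast to, instead of the previous hardcoded upcast to torch.float32. Below is a minimal, self-contained sketch of that cast logic and a hypothetical call site; cast_layernorm_params and LAYER_NORM_NAMES are illustrative stand-ins, not the repository's actual helper or its layer_norm_names default.

    import torch
    from transformers import AutoModelForCausalLM

    # Assumption: substrings that identify LayerNorm parameters by name
    # (the real function receives this via its layer_norm_names argument).
    LAYER_NORM_NAMES = ("norm", "ln")

    def cast_layernorm_params(model, layernorm_dtype: torch.dtype):
        # Cast every 1-D parameter whose name matches a LayerNorm pattern
        # to the requested dtype, mirroring the parameterized cast above.
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(n in name for n in LAYER_NORM_NAMES):
                param.data = param.data.to(layernorm_dtype)
        return model

    # Example: load the model in bfloat16 but keep LayerNorm weights in
    # float32 for numerical stability during training.
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
    model = cast_layernorm_params(model, layernorm_dtype=torch.float32)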