Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix layer norm dtype
@@ -31,6 +31,7 @@ def find_all_linear_modules(
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
+    layernorm_dtype: torch.dtype,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
@@ -45,7 +46,7 @@ def prepare_model_for_training(
     """
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(layernorm_dtype)
 
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
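For readers skimming the change: the cast that used to force layer-norm parameters to torch.float32 now uses a caller-supplied layernorm_dtype. A minimal sketch of the idea follows, assuming a plain PyTorch module; the LAYERNORM_NAMES list, the cast_layernorm_params helper, and the TinyBlock model below are illustrative stand-ins, not the repository's actual code.

import torch
import torch.nn as nn

# Illustrative name list; the repository defines its own LAYERNORM_NAMES constant.
LAYERNORM_NAMES = ["norm", "ln_f", "layernorm", "layer_norm"]

def cast_layernorm_params(model: nn.Module, layernorm_dtype: torch.dtype) -> nn.Module:
    # Cast 1-D parameters whose names look like layer norms to the requested dtype,
    # mirroring the patched loop where the hard-coded torch.float32 became an argument.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(layernorm_dtype)
    return model

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.layer_norm = nn.LayerNorm(8)

    def forward(self, x):
        return self.layer_norm(self.proj(x))

# Example: model weights in bf16, but layer-norm parameters kept in fp32.
model = TinyBlock().to(torch.bfloat16)
cast_layernorm_params(model, layernorm_dtype=torch.float32)
print({name: param.dtype for name, param in model.named_parameters()})

With a configurable dtype, a bf16 training setup can still keep its layer norms in float32 for stability, or let them follow the model dtype when that is preferred.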