refactor model_dtype, fix PPO trainer

hiyouga
2023-10-11 23:16:01 +08:00
parent 5310e4d182
commit 2818af0b09
10 changed files with 104 additions and 119 deletions

@@ -31,11 +31,11 @@ def find_all_linear_modules(
 def prepare_model_for_training(
     model: "PreTrainedModel",
-    layernorm_dtype: torch.dtype,
+    upcast_layernorm: bool,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
@@ -44,9 +44,10 @@ def prepare_model_for_training(
         (3) upcast the lm_head to fp32
     Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
     """
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(layernorm_dtype)
+    if upcast_layernorm:
+        for name, param in model.named_parameters():
+            if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
+                param.data = param.data.to(torch.float32)

     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
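
For reference, a minimal standalone sketch of what the new upcast_layernorm path does when the flag is enabled, assuming a PyTorch model; the helper name upcast_layernorm_params and the example value of LAYERNORM_NAMES are illustrative placeholders rather than code quoted from the repository.

import torch
from torch import nn

# Illustrative stand-in for the LAYERNORM_NAMES constant referenced in the diff.
LAYERNORM_NAMES = ("norm", "ln")

def upcast_layernorm_params(model: nn.Module, layernorm_names=LAYERNORM_NAMES) -> nn.Module:
    # After the refactor, callers pass a boolean upcast_layernorm flag instead of a
    # target dtype: matching 1-dim (layernorm) parameters are always cast to fp32,
    # which keeps mixed-precision training numerically stable.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
            param.data = param.data.to(torch.float32)
    return model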