From 265dc1b6a0fc2d607c2ba2509f45e76b65e77779 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 5 Jul 2023 19:14:10 +0800 Subject: [PATCH] fix bug in PPO stage Former-commit-id: a2ba69183b8e72c09242317a34545ab966ea8991 --- src/utils/ppo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/utils/ppo.py b/src/utils/ppo.py index 701d4b4b..477dce59 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -25,11 +25,10 @@ logger = get_logger(__name__) def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: - if target == "reward": # save original head temporarily + if target == "reward": # save default head temporarily valuehead_state_dict = model.v_head.state_dict() - - setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"]) - setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"]) + setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) + setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"]) model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active model.v_head.load_state_dict({ @@ -40,7 +39,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def def cast_layernorm_dtype( model: AutoModelForCausalLMWithValueHead, - layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting + layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings layer_norm_params: Optional[Dict[str, torch.Tensor]] = None ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: