fix bug in PPO stage

Former-commit-id: a2ba69183b8e72c09242317a34545ab966ea8991
hiyouga 2023-07-05 19:14:10 +08:00
parent 7ba52f5b6e
commit 265dc1b6a0


@@ -25,11 +25,10 @@ logger = get_logger(__name__)
 def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
-    if target == "reward": # save original head temporarily
+    if target == "reward": # save default head temporarily
         valuehead_state_dict = model.v_head.state_dict()
-        setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
-        setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])
+        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
+        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
     model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
     model.v_head.load_state_dict({
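
For context, this head swap is what lets a single LoRA model act as both policy and reward model during PPO. Below is a minimal sketch (not code from this commit) of how the renamed default_head_* attributes come into play when scoring a batch; the helper name compute_rewards and the variables unwrapped_model and batch are illustrative assumptions.

import torch

@torch.no_grad()
def compute_rewards(unwrapped_model, batch):
    # Hypothetical helper, not part of this commit: switch to the reward
    # adapter and reward value head, score the batch, then restore the
    # default (policy) head stashed by replace_model above.
    replace_model(unwrapped_model, target="reward")
    _, _, values = unwrapped_model(**batch)           # value head now outputs reward scores
    rewards = [score[-1] for score in values]         # take the score at the last token of each sequence
    replace_model(unwrapped_model, target="default")  # restore the saved default_head_* weights
    return rewards
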
@@ -40,7 +39,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
 def cast_layernorm_dtype(
     model: AutoModelForCausalLMWithValueHead,
-    layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
+    layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
     layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
 ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
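
Judging from its signature, cast_layernorm_dtype is meant to be called in pairs around generation: the first call snapshots and casts the named LayerNorm weights and returns the snapshot, and the second call restores them from that dict. A minimal usage sketch under that assumption, with a hypothetical wrapper name and tokenized inputs:

def generate_with_cast_norms(model, inputs):
    # Hypothetical wrapper, not part of this commit.
    model, layer_norm_params = cast_layernorm_dtype(model)                        # first call: save originals and cast
    response = model.generate(**inputs)                                           # generation runs with the cast norm weights
    model, _ = cast_layernorm_dtype(model, layer_norm_params=layer_norm_params)   # second call: restore the snapshot
    return response
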