fix ppo in trl 0.8.6

This commit is contained in:
hiyouga
2024-06-07 04:48:29 +08:00
parent f9e818d79c
commit 2702d7e952
4 changed files with 37 additions and 25 deletions

View File

@@ -1,7 +1,9 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available
@@ -28,16 +30,27 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r"""
Replaces the default/reward modules in the model. The model is already unwrapped (and gathered).
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
device = model.v_head.summary.weight.device
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
with context_maybe_zero3:
if target == "reward": # save default head temporarily
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
device = model.v_head.summary.weight.device
model.v_head.summary.weight.data = (
model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
)
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: