Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
fix bug in PPO stage
Former-commit-id: a2ba69183b8e72c09242317a34545ab966ea8991
Parent: 7ba52f5b6e
Commit: 265dc1b6a0
```diff
@@ -25,11 +25,10 @@ logger = get_logger(__name__)
 
 
 def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
-    if target == "reward": # save original head temporarily
+    if target == "reward": # save default head temporarily
         valuehead_state_dict = model.v_head.state_dict()
-
-        setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
-        setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])
+        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
+        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
 
     model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
     model.v_head.load_state_dict({
```
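The rename is the actual bug fix: `replace_model` restores the value head via `getattr(model, "{}_head_weight".format(target))`, so switching back with `target == "default"` looks up `default_head_weight`, an attribute the old code (which saved `origin_head_weight`) never set. Below is a minimal sketch of the fixed function reconstructed around this hunk; the completion of the `load_state_dict` call and the `reward_head_weight`/`reward_head_bias` attributes (assumed to be attached when the reward model's value head is loaded) are inferred, not shown verbatim in this diff:

```python
from typing import Literal
from trl import AutoModelForCausalLMWithValueHead

def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # save default head temporarily
        valuehead_state_dict = model.v_head.state_dict()
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({ # swap in the weights of the chosen value head
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })
```

With this in place, a PPO step can flip the same backbone between roles: `replace_model(model, target="reward")` before scoring generated responses, then `replace_model(model, target="default")` to restore the policy's adapter and value head before optimization.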
```diff
@@ -40,7 +39,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
 
 def cast_layernorm_dtype(
     model: AutoModelForCausalLMWithValueHead,
-    layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
+    layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
     layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
 ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
 
```
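This second change only widens the default name list so the matching also covers Falcon's `ln_attn`/`ln_mlp` layer norms. For context, here is a minimal sketch of a `cast_layernorm_dtype` consistent with this signature: it upcasts matching 1-D parameters to float32 for numerically stable PPO training, and when called again with the saved `layer_norm_params` it restores the original half-precision weights. The function body is an assumption inferred from the signature and return type, not part of this diff:

```python
import torch
from typing import Dict, List, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead

def cast_layernorm_dtype(
    model: AutoModelForCausalLMWithValueHead,
    layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        # layer norm weights are 1-D; match them by the configured name substrings
        if param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name] # restore the saved half-precision weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone() # keep a copy before upcasting
                param.data = param.data.to(torch.float32)

    return model, layer_norm_state_dict
```

Usage would then be a round trip: `model, norms = cast_layernorm_dtype(model)` before the dtype-sensitive phase, and `cast_layernorm_dtype(model, layer_norm_params=norms)` afterwards to put the float16 weights back.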