mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
tiny fix
Former-commit-id: f8d8690bf4c2981f3151b4ccf07daeb4f3cd38a9
This commit is contained in:
parent
4f3c89a6eb
commit
ca9468ff04
@ -32,10 +32,11 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
||||
r"""
|
||||
Replaces the default/reward modules in the model. The model is already unwrapped.
|
||||
"""
|
||||
v_head_layer = model.v_head.summary
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed # type: ignore
|
||||
|
||||
params = [model.v_head.summary.weight, model.v_head.summary.bias]
|
||||
params = [v_head_layer.weight, v_head_layer.bias]
|
||||
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
|
||||
else:
|
||||
context_maybe_zero3 = nullcontext()
|
||||
@ -43,14 +44,12 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
with context_maybe_zero3:
|
||||
if target == "reward": # save default head temporarily
|
||||
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
|
||||
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
|
||||
setattr(model, "default_head_weight", v_head_layer.weight.data.detach().clone())
|
||||
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
|
||||
|
||||
device = model.v_head.summary.weight.device
|
||||
model.v_head.summary.weight.data = (
|
||||
model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||
)
|
||||
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||
device = v_head_layer.weight.device
|
||||
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user