fix ppo trainer

Former-commit-id: 5431be42f9c43095d478f2250fac64ef189eb3ad
hiyouga 2023-12-28 18:09:28 +08:00
parent 024b0b1ab2
commit d0946f08db

@@ -203,7 +203,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Generates model's responses given queries.
         """
-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             layernorm_params = dump_layernorm(self.model)

         if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
@@ -218,7 +218,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             **batch
         )

-        if self.finetuning_args.upcast_layernorm:
+        if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)

         query = batch["input_ids"].detach().cpu()
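
Context for the change: upcast_layernorm is a model-loading option (it upcasts layernorm weights to fp32 when the model is built), so the flag belongs to model_args rather than finetuning_args; the diff makes the PPO trainer read it from the right place. The code below is a minimal sketch of what the dump_layernorm/restore_layernorm pair around generation could look like, assuming the helpers snapshot the fp32 (upcast) parameters, temporarily cast them down so generate runs in a uniform dtype, and put them back afterwards, and assuming the compute dtype is reachable as model.config.torch_dtype. The actual helpers in the repo may differ in detail.

import torch
from typing import Dict, Optional


def dump_layernorm(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    # Snapshot every fp32 parameter (the upcast layernorm weights), then
    # cast it down to the model's compute dtype so generation sees a
    # single, uniform dtype across all parameters.
    layernorm_params = {}
    for name, param in model.named_parameters():
        if param.data.dtype == torch.float32:
            layernorm_params[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)
    return layernorm_params


def restore_layernorm(
    model: torch.nn.Module,
    layernorm_params: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
    # Restore the saved fp32 parameters after generation so PPO training
    # continues with the upcast layernorm weights.
    if layernorm_params is None:
        return
    for name, param in model.named_parameters():
        if name in layernorm_params:
            param.data = layernorm_params[name]

This mirrors the pattern in the diff: dump before calling generate, restore right after, with both calls guarded by the same self.model_args.upcast_layernorm flag.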