mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 19:06:26 +08:00
@@ -203,7 +203,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||
@@ -218,7 +218,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
**batch
|
||||
)
|
||||
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
if self.model_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query = batch["input_ids"].detach().cpu()
|
||||
|
||||
Reference in New Issue
Block a user