fix ppo args

This commit is contained in:
hiyouga
2023-10-11 23:40:50 +08:00
parent 2818af0b09
commit 11bd271364
4 changed files with 18 additions and 9 deletions

View File

@@ -206,7 +206,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
@@ -251,7 +251,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
logits, _, values = model(**input_kwargs)
if values.size(0) != input_ids.size(0): # adapt to chatglm2