Former-commit-id: 4a6ca621c0
This commit is contained in:
hiyouga
2024-04-01 22:53:52 +08:00
parent 8d987b7af7
commit 829cf6458a
4 changed files with 23 additions and 15 deletions

View File

@@ -353,7 +353,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
+_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
values = torch.transpose(values, 0, 1)