commit ba6d290d0b (parent bb6b4823ad)
Author: hiyouga
Date: 2023-11-30 21:02:00 +08:00
Former-commit-id: 1585962eb7ed042890d4c56422aae749c669dda8

2 changed files with 4 additions and 2 deletions


@@ -298,7 +298,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
                 logits, _, values = model(**input_kwargs)

-            if getattr(model.config, "model_type", None) == "chatglm":
+            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+            if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
                 values = torch.transpose(values, 0, 1)

             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
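
For context: `self.accelerator.prepare` wraps the model (e.g. in DistributedDataParallel or a DeepSpeed engine), and the wrapped object may not expose the underlying Transformers config the way the bare model does, which is why the check now reads the config from the unwrapped model. A minimal sketch of that pattern, assuming only that `accelerate` and `transformers` are installed ("gpt2" is just an illustrative checkpoint, not the project's model):

    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()
    # prepare() may wrap the model depending on the launch configuration
    model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))

    # Reading .config off the wrapper can be unreliable; unwrap first, as the diff above does.
    unwrapped_model = accelerator.unwrap_model(model)
    print(getattr(unwrapped_model.config, "model_type", None) == "chatglm")  # False for gpt2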


@@ -40,7 +40,8 @@ class PairwiseTrainer(Trainer):
         # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-        if getattr(model.config, "model_type", None) == "chatglm":
+        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)

         # Split the inputs and rewards into two parts, chosen and rejected
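
On the transpose itself: both trainers swap the first two dimensions only for ChatGLM, which implies its value-head output arrives with the sequence dimension first rather than the batch dimension. A shape-only sketch with made-up sizes:

    import torch

    batch_size, seq_len = 4, 16
    values = torch.randn(seq_len, batch_size)   # chatglm-style layout (illustrative sizes)
    values = torch.transpose(values, 0, 1)      # -> (batch_size, seq_len), as expected downstream
    assert values.shape == (batch_size, seq_len)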