mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-05 21:22:50 +08:00
parent
ae1048db6d
commit
ecfc7d1b50
@ -250,7 +250,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
reward_model = self.reward_model if self.reward_model is not None else self.model
|
reward_model = self.reward_model if self.reward_model is not None else self.model
|
||||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
rewards = []
|
rewards = []
|
||||||
@ -298,7 +298,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||||
logits, _, values = model(**input_kwargs)
|
logits, _, values = model(**input_kwargs)
|
||||||
|
|
||||||
if values.size(0) != input_ids.size(0): # adapt to chatglm2
|
if getattr(model.config, "model_type", None) == "chatglm":
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||||
|
@ -39,7 +39,8 @@ class PairwiseTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
# Compute rewards
|
# Compute rewards
|
||||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||||
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
|
|
||||||
|
if getattr(model.config, "model_type", None) == "chatglm":
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
|
|
||||||
# Split the inputs and rewards into two parts, chosen and rejected
|
# Split the inputs and rewards into two parts, chosen and rejected
|
||||||
|
Loading…
x
Reference in New Issue
Block a user