mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix dpo trainer
Former-commit-id: 074745b1707f98e092749f57041d866c5d55bc04
This commit is contained in:
parent
bf872424da
commit
938c4cb132
@ -130,6 +130,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||||
|
|
||||||
|
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||||
|
raise ValueError("Unsloth does not support lora reward model.")
|
||||||
|
|
||||||
if training_args.max_steps == -1 and data_args.streaming:
|
if training_args.max_steps == -1 and data_args.streaming:
|
||||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
Returns:
|
Returns:
|
||||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
|
A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
|
||||||
"""
|
"""
|
||||||
all_logps = self._get_batch_logps(
|
all_logps = self.get_batch_logps(
|
||||||
chosen_logits,
|
chosen_logits,
|
||||||
chosen_labels,
|
chosen_labels,
|
||||||
average_log_prob=True
|
average_log_prob=True
|
||||||
@ -89,7 +89,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
return_dict=True
|
return_dict=True
|
||||||
).logits.to(torch.float32)
|
).logits.to(torch.float32)
|
||||||
|
|
||||||
all_logps = self._get_batch_logps(
|
all_logps = self.get_batch_logps(
|
||||||
all_logits,
|
all_logits,
|
||||||
batch["labels"],
|
batch["labels"],
|
||||||
average_log_prob=False
|
average_log_prob=False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user