mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 23:30:36 +08:00
@@ -52,6 +52,7 @@ class PairwiseTrainer(Trainer):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
 
         super().__init__(**kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
         self.add_callback(FixValueHeadModelCallback)
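
Context for the hunk above: transformers 4.46 renamed the Trainer's tokenizer argument to processing_class, and its reworked gradient-accumulation logic forwards a num_items_in_batch kwarg to any model whose forward() appears to accept it. Setting model_accepts_loss_kwargs = False opts the value-head model out of that path. A minimal sketch of the same shim, assuming a plain Trainer subclass (the class name CompatTrainer is illustrative, not LLaMA-Factory code):

    from transformers import Trainer

    class CompatTrainer(Trainer):
        """Sketch of the __init__ compatibility shim for transformers >= 4.46."""

        def __init__(self, **kwargs):
            # transformers 4.46 renamed Trainer's "tokenizer" kwarg to "processing_class";
            # move the value across so older call sites keep working.
            if "tokenizer" in kwargs:
                kwargs["processing_class"] = kwargs.pop("tokenizer")

            super().__init__(**kwargs)
            # The value-head model's forward() should not receive num_items_in_batch,
            # so overwrite the flag Trainer derived from the model's signature.
            self.model_accepts_loss_kwargs = False
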
@@ -107,8 +108,8 @@ class PairwiseTrainer(Trainer):
 
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1
 
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
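
In the hunk above, kwargs.get() replaces kwargs.pop() so num_items_in_batch is read without being consumed from kwargs, and the comment widens the affected range to transformers 4.46.0-4.46.1, where the Trainer skips its usual division by gradient_accumulation_steps whenever num_items_in_batch is supplied; the manual division restores the expected loss scale. The loss itself is the standard Bradley-Terry pairwise objective, -log sigmoid(r_chosen - r_rejected). A self-contained sketch with dummy scores (the helper name pairwise_rm_loss is illustrative, not LLaMA-Factory API):

    import torch

    def pairwise_rm_loss(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> torch.Tensor:
        # Bradley-Terry objective: -log sigmoid(chosen - rejected), averaged over the batch.
        # The float() casts avoid precision loss when scores arrive in fp16/bf16.
        return -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

    # Dummy scores: chosen responses should be scored higher than rejected ones.
    chosen = torch.tensor([1.2, 0.8, 2.0])
    rejected = torch.tensor([0.3, 1.1, -0.5])
    print(pairwise_rm_loss(chosen, rejected))  # ~0.42 here; a zero margin gives log(2) ~= 0.693 per pair
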