Former-commit-id: 04f78e85af5af14b4c195936623e426a6a128af2
This commit is contained in:
hiyouga
2024-12-27 16:54:39 +00:00
parent 5769a553d2
commit 88b1874c04
7 changed files with 29 additions and 27 deletions

View File

@@ -52,6 +52,7 @@ class PairwiseTrainer(Trainer):
kwargs["processing_class"] = kwargs.pop("tokenizer")
super().__init__(**kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
self.add_callback(FixValueHeadModelCallback)
@@ -107,8 +108,8 @@ class PairwiseTrainer(Trainer):
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)