mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 23:30:36 +08:00
[breaking] support transformers 4.48 (#6628)
Former-commit-id: f154ab175c513a4d7bb866bf2cffc34b77b50508
This commit is contained in:
@@ -25,7 +25,7 @@ from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
@@ -107,10 +107,6 @@ class PairwiseTrainer(Trainer):
|
||||
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
|
||||
|
||||
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
|
||||
|
||||
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
|
||||
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1
|
||||
|
||||
if return_outputs:
|
||||
return loss, (loss, chosen_scores, rejected_scores)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user