[breaking] support transformers 4.48 (#6628)

Former-commit-id: f154ab175c513a4d7bb866bf2cffc34b77b50508
Author: hoshi-hiyouga
Date: 2025-01-31 01:36:33 +08:00
Committed by: GitHub
Parent: e71737351f
Commit: 222423bcef
17 changed files with 53 additions and 105 deletions


@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
 
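With is_transformers_version_equal_to_4_46 dropped, the only remaining import from extras.packages is the generic comparison helper. Its body lives in extras/packages.py and is not part of this diff, so the following is only a minimal sketch, assuming the usual importlib.metadata plus packaging approach, of how such a version gate can work:

from functools import lru_cache
from importlib import metadata

from packaging import version


@lru_cache
def _get_package_version(name: str) -> "version.Version":
    # Fall back to 0.0.0 when the package is not installed.
    try:
        return version.parse(metadata.version(name))
    except Exception:
        return version.parse("0.0.0")


def is_transformers_version_greater_than(content: str) -> bool:
    # Treated as ">= content" here; the exact semantics in
    # extras/packages.py are an assumption for this sketch.
    return _get_package_version("transformers") >= version.parse(content)


# Example: only take the new code path on transformers 4.48+.
if is_transformers_version_greater_than("4.48.0"):
    pass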
@@ -282,19 +282,12 @@ class CustomDPOTrainer(DPOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)
 
     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """
@@ -318,4 +311,4 @@ class CustomDPOTrainer(DPOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric
 
-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
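Forwarding *args and **kwargs keeps the log override source-compatible with upstream Trainer.log, whose signature gained an extra optional argument in recent transformers releases (a start-time value, to the best of my knowledge; the exact parameter name below is an assumption). A minimal, self-contained sketch of the same forwarding pattern:

from typing import Any, Dict, Optional


class BaseTrainer:
    # Stand-in for transformers.Trainer; the start_time parameter name is an
    # assumption, used only to illustrate a parent signature that grew an argument.
    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        print(logs, start_time)


class CustomTrainer(BaseTrainer):
    # Accept and forward extra arguments so the override keeps working
    # regardless of which parent signature is installed.
    def log(self, logs: Dict[str, float], *args: Any, **kwargs: Any) -> None:
        logs = {key: value for key, value in logs.items() if not key.startswith("dummy_")}
        return super().log(logs, *args, **kwargs)


CustomTrainer().log({"loss": 0.1, "dummy_metric": 0.0}, 12.5)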