Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)
fix #6448
@@ -77,6 +77,7 @@ class CustomKTOTrainer(KTOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx

         Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

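Context for the first hunk: as of transformers 4.46, `Trainer` decides from the model's forward signature whether to pass `num_items_in_batch` into `compute_loss`, and that decision affects how the loss is scaled under gradient accumulation. Forcing `self.model_accepts_loss_kwargs = False` right after `Trainer.__init__` pins the trainer to the older scaling path that the `compute_loss` override below already compensates for. A minimal sketch of the same pattern, assuming only the `transformers.Trainer` attributes shown in the diff (the `PatchedTrainer` name is illustrative, not from the commit):

```python
from transformers import Trainer


class PatchedTrainer(Trainer):
    """Sketch of the pattern used by CustomKTOTrainer in the hunk above."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Trainer derives this flag from the model's forward signature; overriding
        # it to False keeps the loss-scaling behavior this codebase expects.
        self.model_accepts_loss_kwargs = False
        # `accelerator` only exists on recent transformers versions, mirroring the
        # guard in the diff.
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")
```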
@@ -252,15 +253,14 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
             if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
             else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps

         return loss

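Two things change in the second hunk: the version-gated branch now reads `num_items_in_batch` with `kwargs.get(...)` instead of `kwargs.pop(..., False)`, so the kwarg stays available to any later consumer, and the early returns become assignments so every path exits through the single `return loss` at the end. A small standalone illustration of the dictionary semantics behind the `pop`/`get` swap, with made-up values:

```python
# Made-up kwargs; only the dict behavior is the point here.
kwargs = {"num_items_in_batch": 128}

popped = dict(kwargs)
value = popped.pop("num_items_in_batch", False)  # returns 128 and removes the key
assert value == 128 and "num_items_in_batch" not in popped

kept = dict(kwargs)
value = kept.get("num_items_in_batch")  # returns 128, the key is still present
assert value == 128 and "num_items_in_batch" in kept
```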