mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
fix incorrect loss value for vlms
This commit is contained in:
@@ -60,7 +60,7 @@ class PairwiseTrainer(Trainer):
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@@ -100,7 +100,7 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
|
||||
|
||||
if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
|
||||
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
|
||||
|
||||
if return_outputs:
|
||||
|
||||
Reference in New Issue
Block a user