fix incorrect loss value for vlms

Former-commit-id: 30567a1487
This commit is contained in:
hiyouga
2024-10-30 08:56:46 +00:00
parent 1b02915d19
commit 584ce3a105
12 changed files with 48 additions and 22 deletions

View File

@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)