[feature] add support for dft loss (#8917)

Author: XLXW
Date: 2025-08-15 23:29:57 +08:00
Committed by: GitHub
parent 86af92ed56
commit 3cff2fd946
4 changed files with 97 additions and 0 deletions


@@ -78,6 +78,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        if finetuning_args.use_dft_loss:
+            from ..trainer_utils import dft_loss_func
+
+            self.compute_loss_func = dft_loss_func
+
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
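
The hunk wires dft_loss_func (imported from trainer_utils) in as the trainer's compute_loss_func, the hook that transformers.Trainer invokes as compute_loss_func(outputs, labels, num_items_in_batch). The function body itself is not part of this hunk; the following is a minimal sketch of a DFT-style loss under the common formulation that reweights each token's cross entropy by the model's detached probability of the target token. The shapes, the -100 ignore index, and the normalization are illustrative assumptions, not the repository's actual code.

    import torch
    import torch.nn.functional as F

    # Sketch only -- not the actual trainer_utils implementation.
    def dft_loss_func(outputs, labels, num_items_in_batch=None):
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        # Shift so that logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Per-token cross entropy; positions labeled -100 (padding/prompt) yield 0.
        per_token_ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        )
        # DFT reweighting: p(target) = exp(-CE), detached so the weight itself
        # contributes no gradient.
        weights = torch.exp(-per_token_ce).detach()
        valid = (shift_labels.view(-1) != -100).float()
        loss = (per_token_ce * weights * valid).sum()
        # Trainer passes num_items_in_batch when it normalizes summed losses;
        # otherwise fall back to a mean over valid token positions.
        denom = num_items_in_batch if num_items_in_batch is not None else valid.sum().clamp(min=1.0)
        return loss / denom

With the function registered as above, enabling the loss then comes down to setting use_dft_loss to true in the finetuning arguments, per the flag this diff introduces.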