[feature] add support for EAFT loss (#9720)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit e944dc442c
parent 68119e5522
Author: yanglele
Committer: GitHub
Date:   2026-01-06 23:07:12 +08:00

4 changed files with 112 additions and 0 deletions

@@ -87,6 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = dft_loss_func
+        elif finetuning_args.use_eaft_loss:
+            from ..trainer_utils import eaft_loss_func
+
+            self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
+                outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
+            )
         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
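
For context, below is a minimal sketch of a loss function matching the signature the diff wires up: `(outputs, labels, num_items_in_batch, alpha)`. The per-token reweighting by the detached label probability raised to `alpha` is an assumption for illustration (the neighboring DFT loss uses a similar detached-probability weighting); the shipped formulation lives in `trainer_utils.eaft_loss_func` and may differ.

```python
# Hedged sketch only -- the p**alpha weighting is assumed, not taken from this PR.
import torch
import torch.nn.functional as F


def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
    # HF model outputs expose logits as an attribute; fall back to tuple form.
    logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]

    # Shift so that tokens < n predict token n (standard causal LM alignment).
    logits = logits[..., :-1, :].contiguous().float()
    labels = labels[..., 1:].contiguous()

    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1)
    per_token_loss = F.cross_entropy(
        logits, labels, ignore_index=-100, reduction="none"
    )

    # Reweight each token's loss by the model's detached predicted probability
    # of the label, raised to alpha: low-confidence tokens are down-weighted.
    with torch.no_grad():
        probs = torch.exp(-per_token_loss)  # p(label) = exp(-CE) per token
    per_token_loss = per_token_loss * probs.pow(alpha)

    valid = labels != -100
    if num_items_in_batch is not None:
        return per_token_loss[valid].sum() / num_items_in_batch
    return per_token_loss[valid].mean()
```

With this wiring in place, the new loss is opt-in: setting `use_eaft_loss` (and optionally `eaft_alpha`) in the finetuning arguments routes `compute_loss_func` through `eaft_loss_func`, while the default cross-entropy path is untouched.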