Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-01-13 09:30:34 +08:00
[feature] add support for EAFT loss (#9720)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -87,6 +87,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = dft_loss_func
+        elif finetuning_args.use_eaft_loss:
+            from ..trainer_utils import eaft_loss_func
+
+            self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
+                outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
+            )
+
         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
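Note that this hunk only wires the loss in: the lambda binds finetuning_args.eaft_alpha as the last positional argument of eaft_loss_func, whose actual implementation lives in ..trainer_utils and is not part of this diff. Below is a minimal, hypothetical sketch of a function with the same call signature, assuming the usual shift-by-one causal-LM cross entropy; the alpha-dependent EAFT weighting itself is deliberately omitted, so this illustrates the interface the trainer expects rather than the real loss.

# Hypothetical sketch of the interface expected by self.compute_loss_func;
# the real eaft_loss_func in ..trainer_utils implements the EAFT weighting.
import torch.nn.functional as F

def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
    # Shift so that token t is predicted from positions < t (causal LM convention).
    shift_logits = logits[..., :-1, :].contiguous().float()
    shift_labels = labels[..., 1:].contiguous()

    # Summed token-level cross entropy; label -100 marks prompt/padding tokens to ignore.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )

    # NOTE: `alpha` (passed in as finetuning_args.eaft_alpha) is where the
    # EAFT-specific per-token weighting would enter; unused in this placeholder.
    _ = alpha

    # Normalize by the global target-token count when the Trainer provides it.
    if num_items_in_batch is not None:
        return loss / num_items_in_batch
    return loss / (shift_labels != -100).sum().clamp(min=1)

With this in place, enabling use_eaft_loss (together with an eaft_alpha value) in the finetuning arguments should route the trainer to the EAFT branch, mirroring how use_dft_loss is handled just above in the same constructor.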