Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-03-02 09:46:00 +08:00)
[algo] add ASFT (#10174)
@@ -17,6 +17,7 @@
 import json
 import os
+from functools import partial
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         processor: Optional["ProcessorMixin"],
         model_args: Optional["ModelArguments"] = None,
         gen_kwargs: Optional[dict[str, Any]] = None,
+        ref_model: Optional["torch.nn.Module"] = None,
         **kwargs,
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
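ASFT relies on a second, frozen reference model, which this commit threads into the trainer through the new ref_model keyword. A minimal sketch of how a caller might supply it; the checkpoint name and the commented trainer call are illustrative placeholders, not the repository's actual workflow code:

from transformers import AutoModelForCausalLM

ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")  # hypothetical checkpoint
ref_model.requires_grad_(False)  # the reference model is never updated, only queried for logits

# trainer = CustomSeq2SeqTrainer(..., tokenizer=tokenizer, ref_model=ref_model)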
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+        self.ref_model = ref_model
+
+        if ref_model is not None:
+            from trl.models.utils import prepare_deepspeed, prepare_fsdp
+
+            if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+                ):  # quantized models are already set on the correct device
+                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
+                if self.accelerator.is_fsdp2:
+                    from accelerate.utils.fsdp_utils import fsdp2_prepare_model
+
+                    self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
+                else:
+                    self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            self.ref_model.eval()
+
         if finetuning_args.use_dft_loss:
             from ..trainer_utils import dft_loss_func
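The block above dispatches on the active distributed backend so the frozen reference model ends up on the right device(s). A standalone sketch of the same dispatch, assuming an accelerate.Accelerator instance; the helper name prepare_reference_model is hypothetical, and the FSDP2 special case is folded into the generic FSDP path for brevity:

import torch
from accelerate import Accelerator


def prepare_reference_model(ref_model: torch.nn.Module, accelerator: Accelerator) -> torch.nn.Module:
    """Place a frozen reference model for the active backend (sketch of the logic above)."""
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        from trl.models.utils import prepare_deepspeed

        # Quantized (4-/8-bit) models are skipped in the diff because they already sit on the right device.
        ref_model = prepare_deepspeed(ref_model, accelerator)
    elif getattr(accelerator.state, "fsdp_plugin", None) is not None:
        from trl.models.utils import prepare_fsdp

        ref_model = prepare_fsdp(ref_model, accelerator)
    else:
        # Plain DDP / single device: wrap without gradient synchronization.
        ref_model = accelerator.prepare_model(ref_model, evaluation_mode=True)

    ref_model.eval()  # inference only; the reference model receives no gradient updates
    return ref_model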
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
                 outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
             )
+        elif finetuning_args.use_asft_loss:
+            from ..trainer_utils import asft_loss_func
+
+            self.compute_loss_func = partial(
+                asft_loss_func,
+                asft_alpha=finetuning_args.asft_alpha,
+            )

         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
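Here functools.partial pre-binds the asft_alpha hyperparameter, so compute_loss can later call compute_loss_func(outputs, labels, ref_logits) with positional arguments only. A toy sketch of that binding pattern; the loss body is a plain cross-entropy placeholder, not the actual asft_loss_func from trainer_utils:

from functools import partial

import torch
import torch.nn.functional as F


def toy_asft_loss(outputs, labels, ref_logits, asft_alpha: float = 0.5) -> torch.Tensor:
    # Placeholder body: shifted cross-entropy that ignores ref_logits and asft_alpha.
    logits = outputs.logits if hasattr(outputs, "logits") else outputs
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)


compute_loss_func = partial(toy_asft_loss, asft_alpha=0.1)  # alpha is fixed once, as in the diff
# later, inside compute_loss: loss = compute_loss_func(outputs, labels, ref_logits)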
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
-        return super().compute_loss(model, inputs, *args, **kwargs)
+        if self.finetuning_args.use_asft_loss:
+            with torch.no_grad():
+                ref_outputs = self.ref_model(
+                    input_ids=inputs["input_ids"],
+                    attention_mask=inputs.get("attention_mask", None),
+                )
+                ref_logits = ref_outputs.logits
+            outputs = model(**inputs)
+            return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
+        else:
+            return super().compute_loss(model, inputs, *args, **kwargs)

     @override
     def prediction_step(
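The override runs the frozen reference model without gradients, then the trainable policy with gradients, and hands both to the ASFT loss. The same pattern written as a standalone helper under the assumptions above (function name and argument types are illustrative):

import torch


def asft_training_step(policy, ref_model, compute_loss_func, inputs: dict) -> torch.Tensor:
    """One ASFT-style loss computation: frozen reference logits plus trainable policy forward."""
    with torch.no_grad():  # no activations are kept for the reference model
        ref_logits = ref_model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
        ).logits

    outputs = policy(**inputs)  # gradient-tracking forward of the model being trained
    return compute_loss_func(outputs, inputs["labels"], ref_logits)

Running the reference forward under torch.no_grad() keeps its activations out of the autograd graph, so the second model costs memory for its weights but adds nothing to the backward pass.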