From 675ce8cc7f70a65de403ccfd05195ca3ea6f3bd4 Mon Sep 17 00:00:00 2001
From: Junyou Su <104712257+susjunyou@users.noreply.github.com>
Date: Thu, 12 Feb 2026 13:12:14 +0800
Subject: [PATCH] [algo] add ASFT (#10174)

---
 examples/extras/asft/llama2_full_asft.yaml  | 45 +++++++++++
 examples/extras/asft/qwen2_full_asft.yaml   | 45 +++++++++++
 src/llamafactory/hparams/finetuning_args.py |  8 ++
 src/llamafactory/train/sft/trainer.py       | 42 ++++++++++-
 src/llamafactory/train/sft/workflow.py      |  7 +-
 src/llamafactory/train/trainer_utils.py     | 83 +++++++++++++++++++++
 6 files changed, 228 insertions(+), 2 deletions(-)
 create mode 100644 examples/extras/asft/llama2_full_asft.yaml
 create mode 100644 examples/extras/asft/qwen2_full_asft.yaml

diff --git a/examples/extras/asft/llama2_full_asft.yaml b/examples/extras/asft/llama2_full_asft.yaml
new file mode 100644
index 000000000..fb1d7f128
--- /dev/null
+++ b/examples/extras/asft/llama2_full_asft.yaml
@@ -0,0 +1,45 @@
+### model
+model_name_or_path: models/Llama-2-7b
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+deepspeed: examples/deepspeed/ds_z0_config.json
+use_asft_loss: true
+asft_alpha: 0.1
+
+### dataset
+dataset: med
+template: llama2
+cutoff_len: 2048
+max_samples: 10000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/llama2-7b/full/asft2
+logging_steps: 1
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 4
+gradient_accumulation_steps: 8
+learning_rate: 2.0e-5
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/examples/extras/asft/qwen2_full_asft.yaml b/examples/extras/asft/qwen2_full_asft.yaml
new file mode 100644
index 000000000..76fd52449
--- /dev/null
+++ b/examples/extras/asft/qwen2_full_asft.yaml
@@ -0,0 +1,45 @@
+### model
+model_name_or_path: models/Qwen2.5-7B
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+deepspeed: examples/deepspeed/ds_z0_config.json
+use_asft_loss: true
+asft_alpha: 0.05
+
+### dataset
+dataset: math
+template: qwen
+cutoff_len: 2048
+max_samples: 10000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen2-7b/full/asft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 4
+gradient_accumulation_steps: 8
+learning_rate: 5.0e-5
+num_train_epochs: 1.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
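For orientation on the two new options used in these configs: use_asft_loss switches the SFT objective to ASFT, and asft_alpha (0.1 and 0.05 above) scales a KL penalty toward a frozen reference model that is added on top of the adaptive cross-entropy term. A hedged summary of the objective as implemented in the trainer_utils.py hunk further below, where \mathcal{M} is the set of target positions whose label is not -100:

    \mathcal{L}_{\mathrm{ASFT}}
        = \mathcal{L}_{\mathrm{DFT}}(\theta)
        + \alpha \cdot \frac{1}{|\mathcal{M}|} \sum_{t \in \mathcal{M}}
          \mathrm{KL}\!\left( p_{\mathrm{ref}}(\cdot \mid x, y_{<t}) \,\middle\|\, p_{\theta}(\cdot \mid x, y_{<t}) \right)

The adaptive term \mathcal{L}_{\mathrm{DFT}} reuses the existing _dft_cross_entropy helper, whose exact token weighting is defined outside this patch.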
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index c089ca67c..053b4ab6a 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -490,6 +490,14 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether to use the DFT loss."},
     )
+    use_asft_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the ASFT loss."},
+    )
+    asft_alpha: float = field(
+        default=0.1,
+        metadata={"help": "The alpha parameter of the ASFT loss, controlling the weight of the KL regularization term."},
+    )
     use_eaft_loss: bool = field(
         default=False,
         metadata={"help": "Whether to use the EAFT loss."},
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index e1bfe194f..7726b3954 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -17,6 +17,7 @@
 
 import json
 import os
+from functools import partial
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         processor: Optional["ProcessorMixin"],
         model_args: Optional["ModelArguments"] = None,
         gen_kwargs: Optional[dict[str, Any]] = None,
+        ref_model: Optional["torch.nn.Module"] = None,
         **kwargs,
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        self.ref_model = ref_model
+
+        if ref_model is not None:
+            from trl.models.utils import prepare_deepspeed, prepare_fsdp
+
+            if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
+                ):  # quantized models are already set on the correct device
+                    self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
+                if self.accelerator.is_fsdp2:
+                    from accelerate.utils.fsdp_utils import fsdp2_prepare_model
+
+                    self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
+                else:
+                    self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            self.ref_model.eval()
+
         if finetuning_args.use_dft_loss:
             from ..trainer_utils import dft_loss_func
 
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
                 outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
             )
+        elif finetuning_args.use_asft_loss:
+            from ..trainer_utils import asft_loss_func
+
+            self.compute_loss_func = partial(
+                asft_loss_func,
+                asft_alpha=finetuning_args.asft_alpha,
+            )
 
         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
-        return super().compute_loss(model, inputs, *args, **kwargs)
+        if self.finetuning_args.use_asft_loss:
+            with torch.no_grad():
+                ref_outputs = self.ref_model(
+                    input_ids=inputs["input_ids"],
+                    attention_mask=inputs.get("attention_mask", None),
+                )
+                ref_logits = ref_outputs.logits
+            outputs = model(**inputs)
+            return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
+        else:
+            return super().compute_loss(model, inputs, *args, **kwargs)
 
     @override
     def prediction_step(
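To show end to end what the new compute_loss path above produces, here is a minimal, self-contained sketch (not part of the patch). The function name asft_loss_sketch and the exact adaptive weight (the detached token probability) are assumptions for illustration, since _dft_cross_entropy is defined elsewhere in trainer_utils.py; the reference forward pass under no_grad, the KL(p_ref || p_theta) term, and the asft_alpha mixing mirror the code in this patch.

import torch
import torch.nn.functional as F


def asft_loss_sketch(policy_logits, ref_logits, labels, alpha=0.1, ignore_index=-100):
    # flatten to [N, vocab] / [N]; asft_loss_func below does the same after shifting by one
    vocab_size = policy_logits.size(-1)
    logits = policy_logits.reshape(-1, vocab_size)
    ref_logits = ref_logits.reshape(-1, vocab_size)
    labels = labels.reshape(-1)
    mask = (labels != ignore_index).float()

    # adaptive cross-entropy: assumed DFT-style weight = detached token probability
    log_p = F.log_softmax(logits, dim=-1)
    token_log_p = log_p.gather(-1, labels.clamp_min(0).unsqueeze(-1)).squeeze(-1)
    weight = token_log_p.detach().exp()
    dft_term = -(weight * token_log_p * mask).sum() / mask.sum()

    # KL(p_ref || p_theta) per token, masked mean over tokens (matches _kl_divergence below)
    q = F.softmax(ref_logits, dim=-1)
    kl = F.kl_div(log_p, q, reduction="none").sum(dim=-1)
    kl_term = (kl * mask).sum() / mask.sum()

    return dft_term + alpha * kl_term


# toy usage with random logits standing in for the policy and the frozen reference model
policy = torch.randn(2, 5, 11)       # [batch, seq, vocab]
reference = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
labels[:, 0] = -100                   # pretend the first position is a masked prompt token
print(asft_loss_sketch(policy, reference, labels, alpha=0.1))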
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 1bf14f1eb..db625028f 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
 from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..trainer_utils import create_modelcard_and_push
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
@@ -52,6 +52,10 @@ def run_sft(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
+    ref_model = None
+    if finetuning_args.use_asft_loss:
+        ref_model = create_ref_model(model_args, finetuning_args)
+
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
 
@@ -124,6 +128,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         gen_kwargs=gen_kwargs,
+        ref_model=ref_model,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index e58316092..5f927ddea 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
+import torch.nn.functional as F
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
@@ -681,6 +682,88 @@ def _dft_cross_entropy(
     return loss
 
 
+def asft_loss_func(
+    outputs,
+    labels: torch.Tensor,
+    ref_logits: torch.Tensor,
+    asft_alpha: float = 0.1,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    logits = outputs.get("logits")
+    if logits is None:
+        return outputs.get("loss", torch.tensor(0.0))
+
+    logits = logits.float()
+
+    # shift for causal LM
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+    shift_ref_logits = ref_logits[..., :-1, :].contiguous()
+
+    vocab_size = shift_logits.size(-1)
+
+    # flatten
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1).to(shift_logits.device)
+
+    return _asft_cross_entropy(
+        policy_logits=shift_logits,
+        policy_labels=shift_labels,
+        ref_logits=shift_ref_logits,
+        asft_alpha=asft_alpha,
+        ignore_index=ignore_index,
+    )
+
+
+def _asft_cross_entropy(
+    policy_logits: torch.Tensor,
+    policy_labels: torch.Tensor,
+    ref_logits: torch.Tensor,
+    asft_alpha: float = 0.1,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    dft_loss = _dft_cross_entropy(
+        policy_logits,
+        policy_labels,
+        ignore_index=ignore_index,
+    )
+
+    kl_loss = _kl_divergence(
+        policy_logits,
+        ref_logits,
+        policy_labels,
+        ignore_index=ignore_index,
+    )
+
+    return dft_loss + asft_alpha * kl_loss
+
+
+def _kl_divergence(
+    policy_logits: torch.Tensor,
+    ref_logits: torch.Tensor,
+    labels: torch.Tensor,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    # log p(y|x)
+    log_p = F.log_softmax(policy_logits, dim=-1)
+
+    # q(y|x)
+    q = F.softmax(ref_logits, dim=-1)
+
+    # token-wise KL
+    kl = F.kl_div(
+        log_p,
+        q,
+        reduction="none",
+    ).sum(dim=-1)  # [N]
+
+    # mask padding tokens
+    mask = (labels != ignore_index).float()
+
+    return (kl * mask).sum() / mask.sum()
+
+
 def eaft_loss_func(
     outputs: "torch.Tensor",
     labels: "torch.Tensor",