From 0454c10456e7a7a56d1e50eedc6c6f4edec0e7b0 Mon Sep 17 00:00:00 2001
From: XLXW
Date: Fri, 15 Aug 2025 23:29:57 +0800
Subject: [PATCH] [feature] add support for dft loss (#8917)

---
 examples/extras/dft/qwen2_full_sft.yaml     | 43 ++++++++++++++++++++
 src/llamafactory/hparams/finetuning_args.py |  4 ++
 src/llamafactory/train/sft/trainer.py       |  5 +++
 src/llamafactory/train/trainer_utils.py     | 45 +++++++++++++++++++++
 4 files changed, 97 insertions(+)
 create mode 100644 examples/extras/dft/qwen2_full_sft.yaml

diff --git a/examples/extras/dft/qwen2_full_sft.yaml b/examples/extras/dft/qwen2_full_sft.yaml
new file mode 100644
index 00000000..865cc856
--- /dev/null
+++ b/examples/extras/dft/qwen2_full_sft.yaml
@@ -0,0 +1,43 @@
+### model
+model_name_or_path: Qwen/Qwen2-1.5B-Instruct
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+use_dft_loss: true
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: qwen
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+dataloader_num_workers: 4
+
+### output
+output_dir: saves/qwen2-1_5b/full/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+save_only_model: false
+report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-5
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+# val_size: 0.1
+# per_device_eval_batch_size: 1
+# eval_strategy: steps
+# eval_steps: 500
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 0a3a2f39..21cf30a1 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -428,6 +428,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
     )
+    use_dft_loss: bool = field(
+        default=False,
+        metadata={"help": "Whether to use the DFT loss."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index d0b8d05b..d378a3a3 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -78,6 +78,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+        if finetuning_args.use_dft_loss:
+            from ..trainer_utils import dft_loss_func
+
+            self.compute_loss_func = dft_loss_func
+
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 7d35bbeb..80a46397 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -631,6 +631,51 @@ def get_batch_logps(
     return logps, valid_length


+def dft_loss_func(outputs, labels, num_items_in_batch=None):
+    logits = outputs.get("logits")
+    if logits is None:
+        return outputs.get("loss", torch.tensor(0.0))
+
+    logits = logits.float()
+    vocab_size = logits.size(-1)
+    labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
+    shift_labels = labels[..., 1:].contiguous()
+    logits = logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1)
+    shift_labels = shift_labels.to(logits.device)
+
+    loss = _dft_cross_entropy(logits, shift_labels, num_items_in_batch)
+    return loss
+
+
+def _dft_cross_entropy(
+    source: torch.Tensor,
+    target: torch.Tensor,
+    num_items_in_batch: Optional[torch.Tensor] = None,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
+    valid_mask = target != ignore_index
+    if not valid_mask.any():
+        return torch.tensor(0.0, device=source.device, dtype=source.dtype)
+
+    valid_losses = per_token_loss[valid_mask]
+
+    with torch.no_grad():
+        target_probs = torch.exp(-valid_losses)
+
+    weighted_losses = valid_losses * target_probs
+
+    if num_items_in_batch is not None:
+        total_loss = weighted_losses.sum()
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.to(total_loss.device)
+        loss = total_loss / num_items_in_batch
+    else:
+        loss = weighted_losses.mean()
+    return loss
+
+
 def nested_detach(
     tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
     clone: bool = False,
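
Note (not part of the patch): a minimal, self-contained sketch of the weighting that dft_loss_func applies. Each token's cross-entropy is rescaled by the detached probability exp(-CE) that the model assigns to its target token, so tokens the model already predicts well keep most of their gradient while low-probability tokens are down-weighted. The tensor shapes and random values below are invented for illustration; the patch itself wires the real implementation into the trainer through compute_loss_func.

# Standalone illustration of the DFT weighting (illustrative only).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)          # dummy model outputs
labels = torch.randint(0, vocab, (batch, seq_len))   # dummy target token ids
labels[0, -2:] = -100                                # mark some positions as padding

# Same shift as dft_loss_func: pad labels by one and drop the first position,
# so logits at position t are scored against the token at position t + 1.
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous().view(-1)
flat_logits = logits.float().view(-1, vocab)

# Per-token cross-entropy, then the DFT weight p(target) = exp(-CE),
# detached so it only rescales the gradient rather than being optimized itself.
per_token = F.cross_entropy(flat_logits, shift_labels, ignore_index=-100, reduction="none")
mask = shift_labels != -100
ce = per_token[mask]
weight = torch.exp(-ce).detach()
dft_loss = (ce * weight).mean()
print(f"plain CE: {ce.mean().item():.4f}  DFT loss: {dft_loss.item():.4f}")

To try the feature end to end, the new example config can be launched like the other extras; use_dft_loss: true is the only switch that differs from a plain full SFT run:

llamafactory-cli train examples/extras/dft/qwen2_full_sft.yaml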