From 2b66b4df43cfd8cdee5130e44270135113902569 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Fri, 15 Aug 2025 15:09:59 +0800
Subject: [PATCH] [feature] Support MPO (#8930)

---
 src/llamafactory/extras/constants.py        | 24 ++++++++++---
 src/llamafactory/hparams/finetuning_args.py |  4 +++
 src/llamafactory/train/dpo/trainer.py       | 40 +++++++++++++++++++--
 3 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 5467e772..586cd3e9 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1663,20 +1663,36 @@ register_model_group(
 
 register_model_group(
     models={
-        "MiMo-7B-VL-Instruct": {
-            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
-            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
-        },
         "MiMo-7B-VL-RL": {
             DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
             DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
         },
+        "MiMo-VL-7B-RL-2508": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
+        },
     },
     template="mimo_vl",
     multimodal=True,
 )
 
 
+register_model_group(
+    models={
+        "MiMo-7B-VL-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
+        },
+        "MiMo-VL-7B-SFT-2508": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
+        },
+    },
+    template="qwen2_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 43596864..0a3a2f39 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -134,6 +134,10 @@ class RLHFArguments:
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
+    pref_bco_weight: float = field(
+        default=0.0,
+        metadata={"help": "The Binary Classifier Optimization (BCO) loss coefficient in DPO training."},
+    )
     pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 63822e88..d1840fae 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -78,6 +78,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.beta = finetuning_args.pref_beta
         self.loss_type = finetuning_args.pref_loss
         self.ftx_gamma = finetuning_args.pref_ftx
+        self.bco_gamma = finetuning_args.pref_bco_weight
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
         self.ld_alpha = finetuning_args.ld_alpha
@@ -108,6 +109,10 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        if self.bco_gamma > 1e-6:
+            from trl.trainer import RunningMoments
+            self.running = RunningMoments(self.accelerator)
+
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
@@ -151,6 +156,25 @@ class CustomDPOTrainer(DPOTrainer):
             simpo_loss = -F.logsigmoid(self.beta * logits)
         return simpo_loss
 
+    def bco_loss(
+        self,
+        chosen_logps: "torch.Tensor",
+        rejected_logps: "torch.Tensor",
+        reference_chosen_logps: "torch.Tensor",
+        reference_rejected_logps: "torch.Tensor",
+    ) -> "torch.Tensor":
+        chosen_logratios = chosen_logps - reference_chosen_logps
+        rejected_logratios = rejected_logps - reference_rejected_logps
+        chosen_rewards = self.beta * chosen_logratios
+        rejected_rewards = self.beta * rejected_logratios
+        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
+        self.running.update(rewards)  # update the running reward baseline
+        delta = self.running.mean
+        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
+            -(self.beta * rejected_logratios - delta)
+        )
+        return bco_loss
+
     def compute_preference_loss(
         self,
         policy_chosen_logps: "torch.Tensor",
@@ -171,8 +195,17 @@ class CustomDPOTrainer(DPOTrainer):
             rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
         else:
             losses, chosen_rewards, rejected_rewards = self.dpo_loss(
-                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
-            )
+                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+
+        if self.bco_gamma > 1e-6:
+            bco_losses = self.bco_loss(
+                policy_chosen_logps,
+                policy_rejected_logps,
+                reference_chosen_logps,
+                reference_rejected_logps,
+            )
+            losses += bco_losses * self.bco_gamma
 
         return losses, chosen_rewards, rejected_rewards
 
@@ -253,6 +286,9 @@ class CustomDPOTrainer(DPOTrainer):
         sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
+        if self.bco_gamma > 1e-6:
+            # re-weighting for MPO
+            losses /= (self.ftx_gamma + self.bco_gamma + 1.0)
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()