[feature] Support MPO (#8930)

Kingsley 2025-08-15 15:09:59 +08:00 committed by GitHub
parent e557e71023
commit 2b66b4df43
3 changed files with 62 additions and 6 deletions

View File

@@ -1663,20 +1663,36 @@ register_model_group(
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
},
"MiMo-7B-VL-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
},
"MiMo-VL-7B-RL-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
},
},
template="mimo_vl",
multimodal=True,
)
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
},
"MiMo-VL-7B-SFT-2508": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
},
},
template="qwen2_vl",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {

View File

@@ -134,6 +134,10 @@ class RLHFArguments:
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
)
pref_bco_weight: float = field(
default=0.0,
metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
)
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."},

View File

@@ -78,6 +78,7 @@ class CustomDPOTrainer(DPOTrainer):
self.beta = finetuning_args.pref_beta
self.loss_type = finetuning_args.pref_loss
self.ftx_gamma = finetuning_args.pref_ftx
self.bco_gemma = finetuning_args.pref_bco_weight
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.simpo_gamma = finetuning_args.simpo_gamma
self.ld_alpha = finetuning_args.ld_alpha
@@ -108,6 +109,10 @@ class CustomDPOTrainer(DPOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if self.bco_gemma >= 1e-6:
from trl.trainer import RunningMoments
self.running = RunningMoments(self.accelerator)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
@@ -151,6 +156,25 @@ class CustomDPOTrainer(DPOTrainer):
simpo_loss = -F.logsigmoid(self.beta * logits)
return simpo_loss
def bco_loss(
self,
chosen_logps: "torch.Tensor",
rejected_logps: "torch.Tensor",
reference_chosen_logps: "torch.Tensor",
reference_rejected_logps: "torch.Tensor"
) -> "torch.Tensor":
chosen_logratios = chosen_logps - reference_chosen_logps
rejected_logratios = rejected_logps - reference_rejected_logps
chosen_rewards = self.beta * chosen_logratios
rejected_rewards = self.beta * rejected_logratios
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
self.running.update(rewards) # update baseline
delta = self.running.mean
bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
)
return bco_loss
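# Illustrative standalone sketch of the BCO term above (the function name and the
# default beta are assumptions): each response is scored against the running
# reward baseline `delta` instead of against its paired counterpart.
import torch
import torch.nn.functional as F

def bco_term(
    chosen_logratios: "torch.Tensor",
    rejected_logratios: "torch.Tensor",
    delta: "torch.Tensor",
    beta: float = 0.1,
) -> "torch.Tensor":
    # delta plays the role of self.running.mean in the trainer above
    return (
        -F.logsigmoid(beta * chosen_logratios - delta)
        - F.logsigmoid(-(beta * rejected_logratios - delta))
    )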
def compute_preference_loss(
self,
policy_chosen_logps: "torch.Tensor",
@@ -171,8 +195,17 @@ class CustomDPOTrainer(DPOTrainer):
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
else:
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
if self.bco_gemma > 1e-6:
bco_losses = self.bco_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps
)
losses += bco_losses * self.bco_gemma
return losses, chosen_rewards, rejected_rewards
@@ -253,6 +286,9 @@ class CustomDPOTrainer(DPOTrainer):
sft_loss = -policy_chosen_logps_avg
if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss
if self.bco_gemma > 1e-6:
# re-weighting for MPO
losses /= (self.ftx_gamma + self.bco_gemma + 1.0)
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()