mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-19 12:12:48 +08:00

[feature] Support MPO (#8930)

This commit is contained in:
parent e557e71023
commit 2b66b4df43
@@ -1663,20 +1663,36 @@ register_model_group(
register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
        "MiMo-7B-VL-RL": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
        },
        "MiMo-VL-7B-RL-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
        },
    },
    template="mimo_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
"MiMo-VL-7B-SFT-2508": {
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
|
||||
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508"
|
||||
},
|
||||
},
|
||||
template="qwen2_vl",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"MiniCPM-2B-SFT-Chat": {
|
||||
|
@@ -134,6 +134,10 @@ class RLHFArguments:
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    pref_bco_weight: float = field(
        default=0.0,
        metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
    )
    pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},
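Read together with the existing pref_ftx coefficient, the new pref_bco_weight argument turns the plain DPO objective into the MPO-style mixture implemented by the trainer hunks below: the BCO term is added with weight pref_bco_weight, the SFT term with weight pref_ftx, and the sum is re-normalized. As a sketch in standard DPO notation (writing gamma_ftx for pref_ftx and gamma_bco for pref_bco_weight; the per-term definitions follow later in this diff, and the denominator is applied only when gamma_bco is non-zero, matching the guard in the re-weighting hunk):

\mathcal{L}_{\mathrm{MPO}} = \frac{\mathcal{L}_{\mathrm{DPO}} + \gamma_{\mathrm{ftx}}\,\mathcal{L}_{\mathrm{SFT}} + \gamma_{\mathrm{bco}}\,\mathcal{L}_{\mathrm{BCO}}}{1 + \gamma_{\mathrm{ftx}} + \gamma_{\mathrm{bco}}}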
@@ -78,6 +78,7 @@ class CustomDPOTrainer(DPOTrainer):
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.bco_gemma = finetuning_args.pref_bco_weight
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.ld_alpha = finetuning_args.ld_alpha
@@ -108,6 +109,10 @@ class CustomDPOTrainer(DPOTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if self.bco_gemma >= 1e-6:
            from trl.trainer import RunningMoments

            self.running = RunningMoments(self.accelerator)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
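RunningMoments is trl's utility for tracking running statistics of rewards; only its running mean is consumed here, as the baseline delta in bco_loss below. A minimal single-process stand-in, purely illustrative and not trl's actual implementation, captures the role it plays:

import torch


class SimpleRunningMean:
    """Illustrative stand-in for the running-mean role of trl's RunningMoments."""

    def __init__(self) -> None:
        self.mean = 0.0
        self.count = 0

    def update(self, value: "torch.Tensor") -> None:
        # Incremental mean update; trl's real class also tracks variance
        # and synchronizes statistics across distributed processes.
        self.count += 1
        self.mean += (float(value) - self.mean) / self.count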
@@ -151,6 +156,25 @@ class CustomDPOTrainer(DPOTrainer):
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def bco_loss(
        self,
        chosen_logps: "torch.Tensor",
        rejected_logps: "torch.Tensor",
        reference_chosen_logps: "torch.Tensor",
        reference_rejected_logps: "torch.Tensor",
    ) -> "torch.Tensor":
        chosen_logratios = chosen_logps - reference_chosen_logps
        rejected_logratios = rejected_logps - reference_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
        self.running.update(rewards)  # update baseline
        delta = self.running.mean
        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
            -(self.beta * rejected_logratios - delta)
        )
        return bco_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
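In standard DPO notation, with pi_theta the policy, pi_ref the reference model, and delta the running mean of implicit rewards maintained by self.running, the bco_loss added above computes

\mathcal{L}_{\mathrm{BCO}} = -\log\sigma\!\left(\beta \log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \delta\right) - \log\sigma\!\left(-\left(\beta \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} - \delta\right)\right)

that is, the chosen response is pushed above the reward baseline and the rejected response below it, which is the Binary Classifier Optimization view of preference learning.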
@@ -171,8 +195,17 @@ class CustomDPOTrainer(DPOTrainer):
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

            if self.bco_gemma > 1e-6:
                bco_losses = self.bco_loss(
                    policy_chosen_logps,
                    policy_rejected_logps,
                    reference_chosen_logps,
                    reference_rejected_logps,
                )
                losses += bco_losses * self.bco_gemma

        return losses, chosen_rewards, rejected_rewards
@@ -253,6 +286,9 @@ class CustomDPOTrainer(DPOTrainer):
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss
        if self.bco_gemma > 1e-6:
            # re-weighting for MPO
            losses /= (self.ftx_gamma + self.bco_gemma + 1.0)

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
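To make the re-weighting concrete, here is a short, self-contained sketch with hypothetical per-sample values (the 0.5 coefficients are illustrative only, not defaults introduced by this commit). It mirrors the order of operations in the diff: the BCO term is folded in inside compute_preference_loss, and the SFT term plus the final normalization are applied in the hunk above:

import torch

# Hypothetical coefficients standing in for pref_bco_weight and pref_ftx.
bco_gamma, ftx_gamma = 0.5, 0.5
dpo_losses = torch.tensor([0.70, 0.40])  # per-sample DPO (sigmoid) losses
bco_losses = torch.tensor([0.60, 0.30])  # per-sample BCO losses
sft_losses = torch.tensor([1.20, 0.90])  # -policy_chosen_logps_avg

losses = dpo_losses + bco_gamma * bco_losses     # BCO term, weighted
losses = losses + ftx_gamma * sft_losses         # SFT term, weighted
losses = losses / (ftx_gamma + bco_gamma + 1.0)  # MPO re-weighting
print(losses)  # tensor([0.8000, 0.5000])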