Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-19 12:12:48 +08:00)
[feature] Support MPO (#8930)

commit 2b66b4df43 (parent e557e71023)
@@ -1663,20 +1663,36 @@ register_model_group(
 register_model_group(
     models={
-        "MiMo-7B-VL-Instruct": {
-            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
-            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
-        },
         "MiMo-7B-VL-RL": {
             DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
             DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
         },
+        "MiMo-VL-7B-RL-2508": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508"
+        }
     },
     template="mimo_vl",
     multimodal=True,
 )
 
 
+register_model_group(
+    models={
+        "MiMo-7B-VL-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
+        },
+        "MiMo-VL-7B-SFT-2508": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508"
+        },
+    },
+    template="qwen2_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
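For context, register_model_group maps each user-facing model name to the hub repository to pull from per download source, and attaches a chat template plus a multimodal flag. A minimal, self-contained sketch of that idea (the enum values and registry names here are illustrative stand-ins, not the project's actual internals):

from enum import Enum, unique


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"  # Hugging Face Hub
    MODELSCOPE = "ms"  # ModelScope


# illustrative registries; the project keeps its own structures for this
SUPPORTED_MODELS: dict = {}
DEFAULT_TEMPLATE: dict = {}
MULTIMODAL_MODELS: set = set()


def register_model_group(models, template=None, multimodal=False):
    for name, sources in models.items():
        SUPPORTED_MODELS[name] = sources  # model name -> {download source: hub repo id}
        if template is not None:
            DEFAULT_TEMPLATE[name] = template  # chat template used by default for this name
        if multimodal:
            MULTIMODAL_MODELS.add(name)


register_model_group(
    models={
        "MiMo-VL-7B-RL-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
        },
    },
    template="mimo_vl",
    multimodal=True,
)

print(SUPPORTED_MODELS["MiMo-VL-7B-RL-2508"][DownloadSource.DEFAULT])  # XiaomiMiMo/MiMo-VL-7B-RL-2508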
@@ -134,6 +134,10 @@ class RLHFArguments:
         default=0.0,
         metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
     )
+    pref_bco_weight: float = field(
+        default=0.0,
+        metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
+    )
     pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
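The new pref_bco_weight field is what turns a plain DPO run into the MPO-style mixed objective wired up in the trainer below. A simplified stand-in dataclass (not the project's real argument classes) showing how the relevant knobs sit next to each other:

from dataclasses import dataclass


# simplified stand-in for the RLHF argument block; field names follow the diff
@dataclass
class RLHFArgsSketch:
    pref_beta: float = 0.1          # DPO beta
    pref_ftx: float = 0.0           # SFT loss coefficient mixed into DPO
    pref_bco_weight: float = 0.0    # BCO coefficient; > 0 enables the MPO-style mix
    pref_loss: str = "sigmoid"      # MPO keeps the standard sigmoid DPO loss


# e.g. a DPO run that also mixes in SFT and BCO terms (values are illustrative)
args = RLHFArgsSketch(pref_ftx=1.0, pref_bco_weight=1.0)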
@@ -78,6 +78,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.beta = finetuning_args.pref_beta
         self.loss_type = finetuning_args.pref_loss
         self.ftx_gamma = finetuning_args.pref_ftx
+        self.bco_gemma = finetuning_args.pref_bco_weight
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
         self.ld_alpha = finetuning_args.ld_alpha
@@ -108,6 +109,10 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+        if self.bco_gemma >= 1e-6:
+            from trl.trainer import RunningMoments
+
+            self.running = RunningMoments(self.accelerator)
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
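trl's RunningMoments keeps running reward statistics across batches; the bco_loss added in the next hunk reads its .mean as the baseline delta. A simplified stand-in covering only the two members the diff relies on, update() and .mean (the real class is constructed with the accelerator, as above, and tracks more than the mean):

import torch


class RunningMeanSketch:
    """Toy running-mean tracker standing in for trl.trainer.RunningMoments."""

    def __init__(self) -> None:
        self.mean = 0.0
        self.count = 0

    def update(self, xs: "torch.Tensor") -> None:
        xs = xs.detach().float().flatten()
        n = xs.numel()
        total = self.count + n
        # incremental mean over all reward values seen so far
        self.mean += (xs.sum().item() - n * self.mean) / total
        self.count = total


baseline = RunningMeanSketch()
baseline.update(torch.tensor([0.2, -0.1, 0.4]))
baseline.update(torch.tensor(0.3))
print(round(baseline.mean, 4))  # 0.2, the running mean of the four reward values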
@@ -151,6 +156,25 @@ class CustomDPOTrainer(DPOTrainer):
         simpo_loss = -F.logsigmoid(self.beta * logits)
         return simpo_loss
 
+    def bco_loss(
+        self,
+        chosen_logps: "torch.Tensor",
+        rejected_logps: "torch.Tensor",
+        reference_chosen_logps: "torch.Tensor",
+        reference_rejected_logps: "torch.Tensor"
+    ) -> "torch.Tensor":
+        chosen_logratios = chosen_logps - reference_chosen_logps
+        rejected_logratios = rejected_logps - reference_rejected_logps
+        chosen_rewards = self.beta * chosen_logratios
+        rejected_rewards = self.beta * rejected_logratios
+        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
+        self.running.update(rewards)  # update baseline
+        delta = self.running.mean
+        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
+            -(self.beta * rejected_logratios - delta)
+        )
+        return bco_loss
+
     def compute_preference_loss(
         self,
         policy_chosen_logps: "torch.Tensor",
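For reference, the BCO term scores chosen and rejected responses independently against the running baseline delta instead of against each other: -logsigmoid(beta * chosen_logratio - delta) - logsigmoid(-(beta * rejected_logratio - delta)). A self-contained toy computation of the same formula on made-up log-probabilities (all numbers are illustrative):

import torch
import torch.nn.functional as F

beta, delta = 0.1, 0.05  # illustrative beta and running-baseline value

# made-up summed log-probabilities for two preference pairs
policy_chosen = torch.tensor([-12.0, -9.5])
policy_rejected = torch.tensor([-14.0, -13.0])
ref_chosen = torch.tensor([-13.0, -10.0])
ref_rejected = torch.tensor([-13.5, -12.0])

chosen_logratios = policy_chosen - ref_chosen        # reward margin of the policy on chosen answers
rejected_logratios = policy_rejected - ref_rejected  # reward margin on rejected answers

# same form as bco_loss above: each side is scored against the baseline, not against the other side
bco = -F.logsigmoid(beta * chosen_logratios - delta) - F.logsigmoid(-(beta * rejected_logratios - delta))
print(bco)  # per-example BCO losses; the trainer scales these by pref_bco_weight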
@@ -171,8 +195,17 @@ class CustomDPOTrainer(DPOTrainer):
             rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
         else:
             losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                 policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
             )
+
+            if self.bco_gemma > 1e-6:
+                bco_losses = self.bco_loss(
+                    policy_chosen_logps,
+                    policy_rejected_logps,
+                    reference_chosen_logps,
+                    reference_rejected_logps
+                )
+                losses += bco_losses * self.bco_gemma
 
         return losses, chosen_rewards, rejected_rewards
 
@@ -253,6 +286,9 @@ class CustomDPOTrainer(DPOTrainer):
         sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
+        if self.bco_gemma > 1e-6:
+            # re-weighting for MPO
+            losses /= (self.ftx_gamma + self.bco_gemma + 1.0)
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
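Taken together, when both pref_ftx and pref_bco_weight are positive the batch loss becomes a weighted mix of the sigmoid DPO loss, the SFT loss, and the BCO loss, renormalized by the sum of the weights: (dpo + ftx_gamma * sft + bco_gamma * bco) / (1 + ftx_gamma + bco_gamma). A toy run of that arithmetic with made-up loss values (the variable names mirror the trainer attributes):

# toy per-batch loss values standing in for the tensors computed in the trainer
dpo_loss, sft_loss, bco_loss = 0.65, 2.10, 0.90
ftx_gamma, bco_gamma = 1.0, 1.0  # pref_ftx and pref_bco_weight

mixed = dpo_loss + bco_gamma * bco_loss + ftx_gamma * sft_loss
mpo_loss = mixed / (ftx_gamma + bco_gamma + 1.0)  # the re-weighting step in the last hunk
print(round(mpo_loss, 4))  # 1.2167, i.e. (0.65 + 0.90 + 2.10) / 3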