mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

[feature] Support MPO (#8930)

parent 41648020db
commit 936f4fd78e

@@ -1663,20 +1663,36 @@ register_model_group(
register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
        "MiMo-7B-VL-RL": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
        },
        "MiMo-VL-7B-RL-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL-2508",
        },
    },
    template="mimo_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiMo-7B-VL-Instruct": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
        },
        "MiMo-VL-7B-SFT-2508": {
            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT-2508",
        },
    },
    template="qwen2_vl",
    multimodal=True,
)


register_model_group(
    models={
        "MiniCPM-2B-SFT-Chat": {

@@ -134,6 +134,10 @@ class RLHFArguments:
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    pref_bco_weight: float = field(
        default=0.0,
        metadata={"help": "The Binary Classifier Optimization coefficient in DPO training."},
    )
    pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},

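The new pref_bco_weight field follows the same pattern as the neighboring preference hyperparameters: these are plain dataclass fields parsed in the HfArgumentParser style, so the coefficient surfaces directly as a config key or CLI flag. A minimal standalone sketch of that mechanism (PrefArgs below is a hypothetical stand-in, not the repo's RLHFArguments class):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class PrefArgs:  # hypothetical stand-in for the RLHFArguments fields shown above
    pref_ftx: float = field(default=0.0, metadata={"help": "SFT loss coefficient in DPO training."})
    pref_bco_weight: float = field(default=0.0, metadata={"help": "BCO coefficient in DPO training."})


if __name__ == "__main__":
    # e.g. python sketch.py --pref_ftx 1.0 --pref_bco_weight 1.0
    (args,) = HfArgumentParser(PrefArgs).parse_args_into_dataclasses()
    print(args.pref_ftx, args.pref_bco_weight)
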
@@ -78,6 +78,7 @@ class CustomDPOTrainer(DPOTrainer):
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.bco_gemma = finetuning_args.pref_bco_weight
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.ld_alpha = finetuning_args.ld_alpha

@@ -108,6 +109,10 @@ class CustomDPOTrainer(DPOTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if self.bco_gemma >= 1e-6:
            from trl.trainer import RunningMoments
            self.running = RunningMoments(self.accelerator)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:

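The lazy import pulls in trl's RunningMoments only when BCO is enabled; it tracks a running estimate of the mean reward, which bco_loss later reads as its baseline. A standalone sketch of just that interaction, using a fresh Accelerator and dummy reward values rather than the trainer's own state:

import torch
from accelerate import Accelerator
from trl.trainer import RunningMoments

running = RunningMoments(Accelerator())
running.update(torch.tensor([0.2, -0.1, 0.4]))  # fold a batch of reward values into the running statistics
delta = running.mean  # the baseline that bco_loss subtracts from per-sample rewards
print(delta)
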
@@ -151,6 +156,25 @@ class CustomDPOTrainer(DPOTrainer):
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def bco_loss(
        self,
        chosen_logps: "torch.Tensor",
        rejected_logps: "torch.Tensor",
        reference_chosen_logps: "torch.Tensor",
        reference_rejected_logps: "torch.Tensor",
    ) -> "torch.Tensor":
        chosen_logratios = chosen_logps - reference_chosen_logps
        rejected_logratios = rejected_logps - reference_rejected_logps
        chosen_rewards = self.beta * chosen_logratios
        rejected_rewards = self.beta * rejected_logratios
        rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
        self.running.update(rewards)  # update the running reward baseline
        delta = self.running.mean
        bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
            -(self.beta * rejected_logratios - delta)
        )
        return bco_loss

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",

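Spelled out (my transcription of the code above, not wording from the commit): with the implicit reward r(y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} and \delta = self.running.mean as the running baseline, bco_loss computes per preference pair

    \mathcal{L}_{\mathrm{BCO}} = -\log \sigma\big(r(y_w) - \delta\big) - \log \sigma\big(-(r(y_l) - \delta)\big)

so each response is scored against the baseline independently rather than against its paired response.
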
@@ -171,8 +195,17 @@ class CustomDPOTrainer(DPOTrainer):
            rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
        else:
            losses, chosen_rewards, rejected_rewards = self.dpo_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
            )

            if self.bco_gemma > 1e-6:
                bco_losses = self.bco_loss(
                    policy_chosen_logps,
                    policy_rejected_logps,
                    reference_chosen_logps,
                    reference_rejected_logps,
                )
                losses += bco_losses * self.bco_gemma

        return losses, chosen_rewards, rejected_rewards

@@ -253,6 +286,9 @@ class CustomDPOTrainer(DPOTrainer):
        sft_loss = -policy_chosen_logps_avg
        if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss
            if self.bco_gemma > 1e-6:
                # re-weighting for MPO: normalize by the sum of the loss coefficients
                losses /= (self.ftx_gamma + self.bco_gemma + 1.0)

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
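
Putting the coefficients together (a reading of the diff, not a formula stated by the commit): when both pref_ftx and pref_bco_weight are positive, with \gamma_{\mathrm{ftx}} = pref_ftx, \gamma_{\mathrm{bco}} = pref_bco_weight, and \mathcal{L}_{\mathrm{SFT}} = -policy_chosen_logps_avg, the assembled batch loss is

    \mathcal{L}_{\mathrm{MPO}} = \frac{\mathcal{L}_{\mathrm{DPO}} + \gamma_{\mathrm{ftx}} \, \mathcal{L}_{\mathrm{SFT}} + \gamma_{\mathrm{bco}} \, \mathcal{L}_{\mathrm{BCO}}}{1 + \gamma_{\mathrm{ftx}} + \gamma_{\mathrm{bco}}}

With pref_bco_weight left at its default of 0.0, the BCO term and the re-weighting are skipped, so existing DPO (and DPO + pref_ftx) behavior is unchanged.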