mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00

update dpo, kto trainer

Former-commit-id: 7c8e01bb74bb2d2da5dba5059a9c262e4730b802
This commit is contained in:
parent 14f6cc2b7c
commit bfac965f9c
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
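The import change is identical in both trainers: `disable_dropout_in_model` is now taken from the `trl.trainer` package export rather than the `trl.trainer.utils` submodule. Both paths resolve to the same function in trl, so a minimal compatibility sketch (not part of this commit) that tolerates either layout could look like:

try:
    from trl.trainer import disable_dropout_in_model  # package-level re-export
except ImportError:
    # older layout: import directly from the utils submodule
    from trl.trainer.utils import disable_dropout_in_model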
@@ -179,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Computes log probabilities of the reference model.
@@ -188,8 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
             return None, None
 
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
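Together, the two hunks above change `compute_reference_log_probs` to take the `model` argument handed to the loss computation, so the adapter-disabling trick targets the module actually running the forward pass instead of `self.model`. A minimal sketch of the resulting pattern, assuming the trainer attributes shown in the diff (`self.ref_model`, `self.accelerator`) and a PEFT-wrapped policy model; the no-grad block is assumed from the surrounding file, not shown in this diff:

from contextlib import nullcontext

import torch

def compute_reference_log_probs(self, model, batch):
    # Sketch only, not the verbatim method body from the commit.
    if self.ref_model is None:
        # No dedicated reference model: reuse the policy model with its
        # LoRA adapter disabled, which recovers the frozen base weights.
        ref_model = model
        ref_context = self.accelerator.unwrap_model(model).disable_adapter()
    else:
        ref_model = self.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        # concatenated_forward scores the chosen and rejected responses in one pass.
        reference_chosen_logps, reference_rejected_logps, _, _ = self.concatenated_forward(ref_model, batch)

    return reference_chosen_logps, reference_rejected_logps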
@@ -221,7 +221,7 @@ class CustomDPOTrainer(DPOTrainer):
             policy_rejected_logits,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_logps,
             policy_rejected_logps,
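These four log-probability tensors are all that `compute_preference_loss` needs; in the default sigmoid formulation it reduces to the standard DPO objective, stated here for reference with y_w/y_l the chosen/rejected responses:

\mathcal{L}_{\mathrm{DPO}}(\theta) = -\log \sigma\Big(\beta \big[\big(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)\big) - \big(\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big)\big]\Big)

The chosen/rejected rewards it returns are the scaled log-ratios \beta(\log \pi_\theta - \log \pi_{\mathrm{ref}}) for each side.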

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -150,14 +150,14 @@ class CustomKTOTrainer(KTOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Computes log probabilities of the reference model.
         """
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
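The KTO variant follows the same pattern but returns three tensors: alongside the chosen/rejected scores it produces `reference_kl_logps`, the reference model's log probabilities on KTO's mismatched prompt-completion pairs, which serve as the KL baseline. Under the same assumptions as the DPO sketch above, the tail of the method would look like:

with torch.no_grad(), ref_context:
    # The KTO concatenated_forward returns five tensors (see the hunk above);
    # only the three log-probability tensors are needed for the reference pass.
    reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps = self.concatenated_forward(
        ref_model, batch
    )

return reference_chosen_logps, reference_rejected_logps, reference_kl_logps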
@@ -190,7 +190,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_kl_logps,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+            model, batch
+        )
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,
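`kto_loss` then folds these into per-example losses. As a hedged summary of the KTO formulation (following the KTO paper; trl's implementation details may differ), with r_\theta(x, y) = \log \pi_\theta(y \mid x) - \log \pi_{\mathrm{ref}}(y \mid x) and the KL baseline z estimated from the mismatched pairs:

z = \max\big(0,\; \hat{\mathbb{E}}\big[\log \pi_\theta(y' \mid x) - \log \pi_{\mathrm{ref}}(y' \mid x)\big]\big)

\mathcal{L}_{\mathrm{KTO}} = \begin{cases} \lambda_D \big(1 - \sigma(\beta (r_\theta - z))\big) & y \text{ desirable (chosen)} \\ \lambda_U \big(1 - \sigma(\beta (z - r_\theta))\big) & y \text{ undesirable (rejected)} \end{cases}

This is why the call site needs policy and reference log probabilities for the chosen, rejected, and KL batches alike.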