clean kto trainer

Former-commit-id: 900e1ea622a2ffa45c5e2a359471962563fabca7
This commit is contained in:
hiyouga 2024-05-28 21:43:26 +08:00
parent 87e71df597
commit 14f6cc2b7c

View File

@ -1,7 +1,7 @@
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
@ -101,42 +101,39 @@ class CustomKTOTrainer(KTOTrainer):
return -all_logps return -all_logps
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor"]:
with torch.no_grad(): r"""
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]} Runs forward pass and computes the log probabilities.
if "pixel_values" in batch: """
kl_model_inputs["pixel_values"] = batch["pixel_values"] batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = {
if "kl_token_type_ids" in batch: "input_ids": batch["{}input_ids".format(prefix)],
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"] "attention_mask": batch["{}attention_mask".format(prefix)],
}
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
if "pixel_values" in batch: if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"] model_inputs["pixel_values"] = batch["pixel_values"]
if "token_type_ids" in batch: if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["token_type_ids"] model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
target_logps = self.get_batch_logps( logps = self.get_batch_logps(
logits=target_logits, logits=logits,
labels=batch["labels"], labels=batch["{}labels".format(prefix)],
average_log_prob=False, average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder, is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id, label_pad_token_id=self.label_pad_token_id,
) )
return logits, logps
kl_logps = self.get_batch_logps( def concatenated_forward(
logits=kl_logits, self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
labels=batch["kl_labels"], ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
average_log_prob=False, target_logits, target_logps = self.forward(model, batch)
is_encoder_decoder=self.is_encoder_decoder, with torch.no_grad():
label_pad_token_id=self.label_pad_token_id, _, kl_logps = self.forward(model, batch, prefix="kl_")
)
if len(target_logps) != len(batch["kto_tags"]): if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.") raise ValueError("Mismatched shape of inputs and labels.")
@ -152,6 +149,30 @@ class CustomKTOTrainer(KTOTrainer):
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def compute_reference_log_probs(
    self, batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    r"""
    Computes the chosen, rejected and KL log probabilities under the reference model.

    If no dedicated reference model is configured, the policy model itself is
    used as the reference with its adapter disabled (LoRA-style training),
    so the forward pass reflects the frozen base weights.
    """
    has_ref_model = self.ref_model is not None
    ref_model = self.ref_model if has_ref_model else self.model
    # disable_adapter() is itself a context manager; nullcontext() keeps the
    # `with` statement uniform when a separate reference model exists.
    ref_context = (
        nullcontext()
        if has_ref_model
        else self.accelerator.unwrap_model(self.model).disable_adapter()
    )

    with torch.no_grad(), ref_context:
        chosen_logps, rejected_logps, _, _, kl_logps = self.concatenated_forward(ref_model, batch)

    return chosen_logps, rejected_logps, kl_logps
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
@ -167,25 +188,9 @@ class CustomKTOTrainer(KTOTrainer):
policy_chosen_logits, policy_chosen_logits,
_, _,
policy_kl_logps, policy_kl_logps,
) = self.forward(model, batch) ) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_kl_logps,
) = self.forward(ref_model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps, policy_chosen_logps,
policy_rejected_logps, policy_rejected_logps,