mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
clean kto trainer
Former-commit-id: 900e1ea622a2ffa45c5e2a359471962563fabca7
This commit is contained in:
parent
87e71df597
commit
14f6cc2b7c
@ -1,7 +1,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
@ -101,42 +101,39 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
return -all_logps
|
return -all_logps
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||||
with torch.no_grad():
|
r"""
|
||||||
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
|
Runs forward pass and computes the log probabilities.
|
||||||
if "pixel_values" in batch:
|
"""
|
||||||
kl_model_inputs["pixel_values"] = batch["pixel_values"]
|
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||||
|
model_inputs = {
|
||||||
if "kl_token_type_ids" in batch:
|
"input_ids": batch["{}input_ids".format(prefix)],
|
||||||
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
|
"attention_mask": batch["{}attention_mask".format(prefix)],
|
||||||
|
}
|
||||||
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
|
||||||
|
|
||||||
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
|
|
||||||
if "pixel_values" in batch:
|
if "pixel_values" in batch:
|
||||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||||
|
|
||||||
if "token_type_ids" in batch:
|
if "{}token_type_ids".format(prefix) in batch:
|
||||||
model_inputs["token_type_ids"] = batch["token_type_ids"]
|
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||||
|
|
||||||
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||||
|
|
||||||
target_logps = self.get_batch_logps(
|
logps = self.get_batch_logps(
|
||||||
logits=target_logits,
|
logits=logits,
|
||||||
labels=batch["labels"],
|
labels=batch["{}labels".format(prefix)],
|
||||||
average_log_prob=False,
|
average_log_prob=False,
|
||||||
is_encoder_decoder=self.is_encoder_decoder,
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
label_pad_token_id=self.label_pad_token_id,
|
label_pad_token_id=self.label_pad_token_id,
|
||||||
)
|
)
|
||||||
|
return logits, logps
|
||||||
|
|
||||||
kl_logps = self.get_batch_logps(
|
def concatenated_forward(
|
||||||
logits=kl_logits,
|
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||||
labels=batch["kl_labels"],
|
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||||
average_log_prob=False,
|
target_logits, target_logps = self.forward(model, batch)
|
||||||
is_encoder_decoder=self.is_encoder_decoder,
|
with torch.no_grad():
|
||||||
label_pad_token_id=self.label_pad_token_id,
|
_, kl_logps = self.forward(model, batch, prefix="kl_")
|
||||||
)
|
|
||||||
|
|
||||||
if len(target_logps) != len(batch["kto_tags"]):
|
if len(target_logps) != len(batch["kto_tags"]):
|
||||||
raise ValueError("Mismatched shape of inputs and labels.")
|
raise ValueError("Mismatched shape of inputs and labels.")
|
||||||
@ -152,6 +149,30 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
|
|
||||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
||||||
|
|
||||||
|
def compute_reference_log_probs(
|
||||||
|
self, batch: Dict[str, "torch.Tensor"]
|
||||||
|
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||||
|
r"""
|
||||||
|
Computes log probabilities of the reference model.
|
||||||
|
"""
|
||||||
|
if self.ref_model is None:
|
||||||
|
ref_model = self.model
|
||||||
|
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||||
|
else:
|
||||||
|
ref_model = self.ref_model
|
||||||
|
ref_context = nullcontext()
|
||||||
|
|
||||||
|
with torch.no_grad(), ref_context:
|
||||||
|
(
|
||||||
|
reference_chosen_logps,
|
||||||
|
reference_rejected_logps,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
reference_kl_logps,
|
||||||
|
) = self.concatenated_forward(ref_model, batch)
|
||||||
|
|
||||||
|
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
def get_batch_loss_metrics(
|
||||||
self,
|
self,
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
@ -167,25 +188,9 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
policy_chosen_logits,
|
policy_chosen_logits,
|
||||||
_,
|
_,
|
||||||
policy_kl_logps,
|
policy_kl_logps,
|
||||||
) = self.forward(model, batch)
|
) = self.concatenated_forward(model, batch)
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
if self.ref_model is None:
|
|
||||||
ref_model = self.model
|
|
||||||
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
|
|
||||||
else:
|
|
||||||
ref_model = self.ref_model
|
|
||||||
ref_context = nullcontext()
|
|
||||||
|
|
||||||
with ref_context:
|
|
||||||
(
|
|
||||||
reference_chosen_logps,
|
|
||||||
reference_rejected_logps,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
reference_kl_logps,
|
|
||||||
) = self.forward(ref_model, batch)
|
|
||||||
|
|
||||||
|
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
|
||||||
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||||
policy_chosen_logps,
|
policy_chosen_logps,
|
||||||
policy_rejected_logps,
|
policy_rejected_logps,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user