refactor data preprocessing, fix mllm rlhf

Former-commit-id: 3a023bca2a
This commit is contained in:
hiyouga
2024-05-24 04:08:25 +08:00
parent 77b5779746
commit 3e729798df
15 changed files with 572 additions and 464 deletions

View File

@@ -4,7 +4,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import BatchEncoding, Trainer
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
@@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
all_logits: "torch.Tensor" = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,

View File

@@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
with torch.no_grad():
kl_logits = model(
input_ids=batch["kl_input_ids"],
attention_mask=batch["kl_attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
if "pixel_values" in batch:
kl_model_inputs["pixel_values"] = batch["pixel_values"]
target_logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
if "kl_token_type_ids" in batch:
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "token_type_ids" in batch:
model_inputs["token_type_ids"] = batch["token_type_ids"]
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
target_logps = self.get_batch_logps(
logits=target_logits,

View File

@@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
r"""
Computes the average log probabilities of the labels under the given logits.
"""
all_logits: "torch.Tensor" = model(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
).logits.to(torch.float32)
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps = self.get_batch_logps(
logits=all_logits,