mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 19:06:26 +08:00
refactor data preprocessing, fix mllm rlhf
Former-commit-id: 3a023bca2a
This commit is contained in:
@@ -4,7 +4,7 @@ from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
@@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
Otherwise the average log probabilities.
|
||||
"""
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits: "torch.Tensor" = model(
|
||||
input_ids=batch_copied["input_ids"],
|
||||
attention_mask=batch_copied["attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
batch_copied = {k: v.detach().clone() for k, v in batch.items()} # avoid error
|
||||
all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
logits=all_logits,
|
||||
|
||||
@@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
kl_logits = model(
|
||||
input_ids=batch["kl_input_ids"],
|
||||
attention_mask=batch["kl_attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
kl_model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
target_logits = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
if "kl_token_type_ids" in batch:
|
||||
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
|
||||
|
||||
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "token_type_ids" in batch:
|
||||
model_inputs["token_type_ids"] = batch["token_type_ids"]
|
||||
|
||||
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
target_logps = self.get_batch_logps(
|
||||
logits=target_logits,
|
||||
|
||||
@@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
|
||||
r"""
|
||||
Computes the average log probabilities of the labels under the given logits.
|
||||
"""
|
||||
all_logits: "torch.Tensor" = model(
|
||||
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
|
||||
).logits.to(torch.float32)
|
||||
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
logits=all_logits,
|
||||
|
||||
Reference in New Issue
Block a user