This commit is contained in:
hiyouga
2025-01-07 06:30:44 +00:00
parent b832ed9a60
commit 870f23d7ea
3 changed files with 32 additions and 10 deletions

View File

@@ -31,7 +31,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
-batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
+batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])