[misc] lint (#9636)

This commit is contained in:
Yaowei Zheng
2025-12-20 16:19:39 +08:00
committed by GitHub
parent b0d49e137f
commit 0894b4f37e
6 changed files with 28 additions and 31 deletions

View File

@@ -218,9 +218,10 @@ class CustomDPOTrainer(DPOTrainer):
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length