From 5af92971bc69d2368ae0005a9894cb814a0ec8f7 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 2 Sep 2024 10:15:29 +0800
Subject: [PATCH] fix trainer predict

Former-commit-id: 99fd9637bdc25f41fd1abc8a162f1069cb9060d4
---
 src/llamafactory/train/sft/trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index d08b2eda..0a5c26f4 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -87,9 +87,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"].detach().clone().cpu() if "labels" in inputs else None  # backup labels (d2h)
+        labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
@@ -101,7 +102,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         )
         if generated_tokens is not None and self.args.predict_with_generate:
             generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
-            generated_tokens = generated_tokens.contiguous().cpu()  # d2h
+            generated_tokens = generated_tokens.contiguous()
 
         return loss, generated_tokens, labels
 