diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index d08b2eda..0a5c26f4 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -87,9 +87,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"].detach().clone().cpu() if "labels" in inputs else None  # backup labels (d2h)
+        labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
@@ -101,7 +102,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         )
         if generated_tokens is not None and self.args.predict_with_generate:
             generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
-            generated_tokens = generated_tokens.contiguous().cpu()  # d2h
+            generated_tokens = generated_tokens.contiguous()
         return loss, generated_tokens, labels