From 9501c3308a01ecce03e952aadd10b509fa4e1411 Mon Sep 17 00:00:00 2001 From: pyx Date: Fri, 6 Mar 2026 18:44:57 +0800 Subject: [PATCH] =?UTF-8?q?[train]=20fix=20compatibility=20issue=20with=20?= =?UTF-8?q?HuggingFace=20Dataset=20Column=20when=20sav=E2=80=A6=20(#10254)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/train/sft/trainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 7726b3954..e3dacf798 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -214,8 +214,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0] if len(pad_len): # move pad token to last preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) + + input_ids_column = dataset["input_ids"] + try: + input_ids_list = input_ids_column.to_pylist() + except AttributeError: + input_ids_list = list(input_ids_column) - decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False) + decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False) decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens) decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)