diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 7726b3954..e3dacf798 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -214,8 +214,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
+
+        input_ids_column = dataset["input_ids"]
+        try:
+            input_ids_list = input_ids_column.to_pylist()
+        except AttributeError:
+            input_ids_list = list(input_ids_column)
 
-        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
         decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
         decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
 