[train] fix compatibility issue with HuggingFace Dataset Column when saving predictions (#10254)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
pyx
2026-03-06 18:44:57 +08:00
committed by GitHub
parent 0ee1c42c2b
commit 9501c3308a

View File

@@ -214,8 +214,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len):  # move pad token to last
    preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
input_ids_column = dataset["input_ids"]
try:
    input_ids_list = input_ids_column.to_pylist()
except AttributeError:
    input_ids_list = list(input_ids_column)
decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)