Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-03-08 04:35:58 +08:00)
[train] fix compatibility issue with HuggingFace Dataset Column when sav… (#10254)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -215,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             if len(pad_len): # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
-        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        input_ids_column = dataset["input_ids"]
+        try:
+            input_ids_list = input_ids_column.to_pylist()
+        except AttributeError:
+            input_ids_list = list(input_ids_column)
+
+        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
         decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
         decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
 
Reference in New Issue
Block a user