mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 20:26:00 +08:00
[train] fix compatibility issue with HuggingFace Dataset Column when sav… (#10254)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -214,8 +214,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
             if len(pad_len):  # move pad token to last
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

+        input_ids_column = dataset["input_ids"]
+        try:
+            input_ids_list = input_ids_column.to_pylist()
+        except AttributeError:
+            input_ids_list = list(input_ids_column)
+
-        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
         decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
         decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

Reference in New Issue
Block a user