diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 393366e6..c42b2047 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -140,9 +140,9 @@ def preprocess_dataset( print("input_ids:\n{}".format(example["input_ids"])) print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) - print("labels:\n{}".format(tokenizer.decode([ - token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"] - ], skip_special_tokens=False))) + print("labels:\n{}".format( + tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) + )) def print_pairwise_dataset_example(example): print("prompt_ids:\n{}".format(example["prompt_ids"]))