diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index b67781d7..c3af364c 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -62,8 +62,10 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

         for query, response, history, system in construct_example(examples):
-            input_ids, labels = [], []
+            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+                continue

+            input_ids, labels = [], []
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
@@ -106,6 +108,9 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+                continue
+
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
@@ -139,6 +144,9 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and query != ""):
+                continue
+
             input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)

             if template.efficient_eos:
@@ -158,7 +166,10 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
-        for query, response, history, system in construct_example(examples):
+        for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
+                continue
+
             prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
             _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)

@@ -203,19 +214,15 @@ def preprocess_dataset(
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

     if stage == "pt":
-        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_func = preprocess_pretrain_dataset
         print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
-        dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
         preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
         print_function = print_supervised_dataset_example
     elif stage == "rm":
-        dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_func = preprocess_pairwise_dataset
         print_function = print_pairwise_dataset_example
     else:
-        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example

@@ -235,9 +242,10 @@ def preprocess_dataset(
             **kwargs
         )

-        try:
-            print_function(next(iter(dataset)))
-        except StopIteration:
-            raise ValueError("Empty dataset!")
+        if training_args.should_log:
+            try:
+                print_function(next(iter(dataset)))
+            except StopIteration:
+                raise ValueError("Empty dataset!")

         return dataset