remove filter in preprocess

Former-commit-id: 2caf91f824320b226daa4666eda2da7cb853db9c
commit 84e27a1c0b
parent daeff710eb
Author: hiyouga
Date: 2023-10-23 23:46:02 +08:00

@@ -62,8 +62,10 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+                continue
             input_ids, labels = [], []
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
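
The guard added above replaces a standalone `dataset.filter(...)` pass: malformed rows are now skipped inside the batched `map` that tokenizes the data. A minimal sketch of the pattern, assuming a toy dataset with `prompt`/`response` columns and a stand-in for the real tokenization (`is_valid_sft_example` is a hypothetical helper, not a name from this repo):

```python
from datasets import Dataset

def is_valid_sft_example(query, response):
    # same predicate the commit adds: both fields must be non-empty strings
    return isinstance(query, str) and isinstance(response, str) and query != "" and response != ""

def preprocess(examples):
    # a batched map may return fewer rows than it receives (once all
    # original columns are removed), so skipping invalid examples here
    # subsumes the separate dataset.filter(...) pass
    model_inputs = {"input_ids": []}
    for query, response in zip(examples["prompt"], examples["response"]):
        if not is_valid_sft_example(query, response):
            continue
        model_inputs["input_ids"].append([0])  # stand-in for real tokenization
    return model_inputs

data = Dataset.from_dict({"prompt": ["hi", ""], "response": ["hello", "dropped"]})
processed = data.map(preprocess, batched=True, remove_columns=data.column_names)
print(len(processed))  # 1 -- the row with the empty prompt was skipped
```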
@@ -106,6 +108,9 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         input_ids, labels = [], []
         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
+                continue
+
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
@@ -139,6 +144,9 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and query != ""):
+                continue
+
             input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
             if template.efficient_eos:
@@ -158,7 +166,10 @@ def preprocess_dataset(
     def preprocess_pairwise_dataset(examples):
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
         for query, response, history, system in construct_example(examples):
+            if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
+                continue
+
             prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
             _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
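
For the pairwise (reward-model) stage the check is different: `response` must be a list with at least two candidates, since `response[0]` is read as the chosen answer and `response[1]` as the rejected one. The predicate in isolation (`is_valid_pairwise_example` is a hypothetical name):

```python
def is_valid_pairwise_example(query, response):
    # mirrors the guard above: non-empty prompt plus a list holding
    # at least a chosen and a rejected answer
    return isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1

print(is_valid_pairwise_example("q", ["good", "bad"]))  # True
print(is_valid_pairwise_example("q", "just a string"))  # False: not a list
print(is_valid_pairwise_example("", ["good", "bad"]))   # False: empty prompt
```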
@@ -203,19 +214,15 @@ def preprocess_dataset(
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 
     if stage == "pt":
-        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_func = preprocess_pretrain_dataset
         print_function = print_unsupervised_dataset_example
     elif stage == "sft" and not training_args.predict_with_generate:
-        dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
         preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
         print_function = print_supervised_dataset_example
     elif stage == "rm":
-        dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
         preprocess_func = preprocess_pairwise_dataset
         print_function = print_pairwise_dataset_example
     else:
-        dataset = dataset.filter(lambda example: example["prompt"])
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
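
These four `filter` calls are what the commit removes: each one ran a full extra pass over the dataset (and, for on-disk datasets, could write another cache file) before the `map` step even started, and the lambdas never checked types. A sketch of the old per-stage pass, using a toy "rm"-style dataset:

```python
from datasets import Dataset

data = Dataset.from_dict({
    "prompt": ["q1", ""],
    "response": [["a", "b"], ["a", "b"]],
})

# old approach (stage == "rm"): a dedicated filter pass before mapping;
# the in-map guard now drops the same rows without this extra traversal
filtered = data.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
print(len(filtered))  # 1
```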
@@ -235,9 +242,10 @@ def preprocess_dataset(
             **kwargs
         )
 
-        try:
-            print_function(next(iter(dataset)))
-        except StopIteration:
-            raise ValueError("Empty dataset!")
+        if training_args.should_log:
+            try:
+                print_function(next(iter(dataset)))
+            except StopIteration:
+                raise ValueError("Empty dataset!")
 
         return dataset
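
Gating the sample printout on `training_args.should_log` keeps distributed runs quiet: with Hugging Face `TrainingArguments`, `should_log` is true only on the process that is supposed to emit logs, so the decoded example prints once rather than once per rank. Note that, as written, the empty-dataset error is now also raised only on that process. A standalone sketch of the guard (`show_first_example` is a hypothetical wrapper):

```python
def show_first_example(dataset, print_function, should_log):
    # print one decoded example on the logging process only; an empty
    # dataset exhausts the iterator immediately and fails loudly
    if should_log:
        try:
            print_function(next(iter(dataset)))
        except StopIteration:
            raise ValueError("Empty dataset!")

show_first_example([{"input_ids": [1, 2]}], print, should_log=True)
```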