mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-05 21:22:50 +08:00
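Remove dataset.filter calls in preprocess_dataset; skip invalid examples inside the preprocess functions instead, and print the sample example only when training_args.should_log is set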
Former-commit-id: 2caf91f824320b226daa4666eda2da7cb853db9c
This commit is contained in:
parent
daeff710eb
commit
84e27a1c0b
@ -62,8 +62,10 @@ def preprocess_dataset(
|
|||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
|
||||||
for query, response, history, system in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
input_ids, labels = [], []
|
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_ids, labels = [], []
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||||
tokenizer, query, response, history, system
|
tokenizer, query, response, history, system
|
||||||
)):
|
)):
|
||||||
@ -106,6 +108,9 @@ def preprocess_dataset(
|
|||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
for query, response, history, system in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
|
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||||
|
continue
|
||||||
|
|
||||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||||
tokenizer, query, response, history, system
|
tokenizer, query, response, history, system
|
||||||
)):
|
)):
|
||||||
@ -139,6 +144,9 @@ def preprocess_dataset(
|
|||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
|
||||||
for query, response, history, system in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
|
if not (isinstance(query, str) and query != ""):
|
||||||
|
continue
|
||||||
|
|
||||||
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||||
|
|
||||||
if template.efficient_eos:
|
if template.efficient_eos:
|
||||||
@ -158,7 +166,10 @@ def preprocess_dataset(
|
|||||||
def preprocess_pairwise_dataset(examples):
|
def preprocess_pairwise_dataset(examples):
|
||||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||||
for query, response, history, system in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
|
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
||||||
|
continue
|
||||||
|
|
||||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||||
|
|
||||||
@ -203,19 +214,15 @@ def preprocess_dataset(
|
|||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
|
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
dataset = dataset.filter(lambda example: example["prompt"])
|
|
||||||
preprocess_func = preprocess_pretrain_dataset
|
preprocess_func = preprocess_pretrain_dataset
|
||||||
print_function = print_unsupervised_dataset_example
|
print_function = print_unsupervised_dataset_example
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
|
||||||
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
||||||
print_function = print_supervised_dataset_example
|
print_function = print_supervised_dataset_example
|
||||||
elif stage == "rm":
|
elif stage == "rm":
|
||||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
|
||||||
preprocess_func = preprocess_pairwise_dataset
|
preprocess_func = preprocess_pairwise_dataset
|
||||||
print_function = print_pairwise_dataset_example
|
print_function = print_pairwise_dataset_example
|
||||||
else:
|
else:
|
||||||
dataset = dataset.filter(lambda example: example["prompt"])
|
|
||||||
preprocess_func = preprocess_unsupervised_dataset
|
preprocess_func = preprocess_unsupervised_dataset
|
||||||
print_function = print_unsupervised_dataset_example
|
print_function = print_unsupervised_dataset_example
|
||||||
|
|
||||||
@ -235,9 +242,10 @@ def preprocess_dataset(
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if training_args.should_log:
|
||||||
print_function(next(iter(dataset)))
|
try:
|
||||||
except StopIteration:
|
print_function(next(iter(dataset)))
|
||||||
raise ValueError("Empty dataset!")
|
except StopIteration:
|
||||||
|
raise ValueError("Empty dataset!")
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user