This commit is contained in:
hiyouga
2023-08-01 18:43:53 +08:00
parent e6a3894b99
commit e3f80774c4
2 changed files with 12 additions and 12 deletions

View File

@@ -18,7 +18,7 @@ def preprocess_dataset(
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> "Dataset":
column_names = list(dataset.column_names or [])
column_names = list(dataset.column_names)
template = get_template(data_args.template)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
@@ -143,15 +143,19 @@ def preprocess_dataset(
if stage == "pt":
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
print_function = print_unsupervised_dataset_example
elif stage == "sft" and not training_args.predict_with_generate:
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
preprocess_function = preprocess_supervised_dataset
print_function = print_supervised_dataset_example
elif stage == "rm":
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
print_function = print_pairwise_dataset_example
else:
dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
print_function = print_unsupervised_dataset_example
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
@@ -172,13 +176,5 @@ def preprocess_dataset(
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
if stage == "pt":
print_unsupervised_dataset_example(next(iter(dataset)))
elif stage == "sft":
print_supervised_dataset_example(next(iter(dataset)))
elif stage == "rm":
print_pairwise_dataset_example(next(iter(dataset)))
elif stage == "ppo":
print_unsupervised_dataset_example(next(iter(dataset)))
print_function(next(iter(dataset)))
return dataset