[misc] Support split eval_dataset when "predict_with_generate" is explicitly set (#9604)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
ZIYI ZENG
2025-12-20 01:46:00 +08:00
committed by GitHub
parent ddd7dcc722
commit b0d49e137f
4 changed files with 37 additions and 30 deletions


@@ -16,7 +16,7 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from datasets import Dataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
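
The only change in this first hunk is the added DatasetDict import. In the datasets library, DatasetDict is a dict-like container mapping split names to Dataset objects, and it forwards operations such as save_to_disk to every split, which is what the second hunk relies on. A minimal sketch of that behavior (the toy data and output path are illustrative, not taken from the commit):

from datasets import Dataset, DatasetDict

train = Dataset.from_dict({"text": ["a", "b"]})
validation = Dataset.from_dict({"text": ["c"]})

# Merging plain dicts of named splits into one DatasetDict,
# the same pattern the new code uses below.
dataset_dict = DatasetDict({**{"train": train}, **{"validation": validation}})
dataset_dict.save_to_disk("tokenized_demo")  # writes one sub-directory per split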
@@ -311,20 +311,22 @@ def get_dataset(
         )

     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
-        dataset = _get_preprocessed_dataset(
-            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
-        )
-        if isinstance(eval_dataset, dict):
-            for eval_name, eval_data in eval_dataset.items():
-                eval_dataset[eval_name] = _get_preprocessed_dataset(
-                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-                )
-        else:
-            eval_dataset = _get_preprocessed_dataset(
-                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+        # split first so that eval_dataset (whether passed in or split off) is preprocessed appropriately
+        train_dict, eval_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        if "train" in train_dict:
+            train_dict["train"] = _get_preprocessed_dataset(
+                train_dict["train"], data_args, training_args, stage, template, tokenizer, processor, is_eval=False
             )

-        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
+        for key in eval_dict:
+            eval_dict[key] = _get_preprocessed_dataset(
+                eval_dict[key], data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
+
+        # Combine train and eval dictionaries
+        dataset_dict = DatasetDict({**train_dict, **eval_dict})

         if data_args.tokenized_path is not None:  # save tokenized dataset to disk
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
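
Net effect of the second hunk: the old code preprocessed the full training set first and only then called split_dataset, so an eval split carved out of the training data was tokenized with is_eval=False; per the commit title, that broke the case where "predict_with_generate" is set and eval examples need the eval-side (is_eval=True) features. The new order splits first, preprocesses the train split with is_eval=False and every eval split with is_eval=True, then merges everything into one DatasetDict. A standalone sketch of the new control flow, with simplified stand-ins for split_dataset and _get_preprocessed_dataset (val_size, the "validation" key, and the is_eval column are illustrative assumptions, not the project's exact signatures):

from datasets import Dataset, DatasetDict

def split_dataset(dataset, eval_dataset, val_size=0.25, seed=42):
    # Simplified stand-in: carve an eval split out of the training data
    # when no explicit eval_dataset is given, mirroring the split-first order.
    if eval_dataset is None and val_size > 0:
        parts = dataset.train_test_split(test_size=val_size, seed=seed)
        return {"train": parts["train"]}, {"validation": parts["test"]}
    return {"train": dataset}, ({"validation": eval_dataset} if eval_dataset else {})

def get_preprocessed(dataset, is_eval):
    # Stand-in for _get_preprocessed_dataset: tag each row so the
    # eval-specific path (is_eval=True) is visible in the output.
    return dataset.map(lambda example: {"is_eval": is_eval})

full = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
train_dict, eval_dict = split_dataset(full, None)

# Same shape as the new code in the hunk above.
if "train" in train_dict:
    train_dict["train"] = get_preprocessed(train_dict["train"], is_eval=False)
for key in eval_dict:
    eval_dict[key] = get_preprocessed(eval_dict[key], is_eval=True)

dataset_dict = DatasetDict({**train_dict, **eval_dict})
print(dataset_dict)  # DatasetDict with "train" and "validation" splits

Because the eval split now exists before preprocessing runs, it always goes through the is_eval=True path, regardless of whether it was supplied explicitly or produced by the split.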