simplify code

This commit is contained in:
hiyouga
2023-07-20 15:08:57 +08:00
parent d1d8e8bae1
commit 67a2773074
18 changed files with 52 additions and 136 deletions

View File

@@ -0,0 +1,16 @@
from typing import Dict
from datasets import Dataset
def split_dataset(
dataset: Dataset, dev_ratio: float, do_train: bool
) -> Dict[str, Dataset]:
# Split the dataset
if do_train:
if dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}