From d0891f05fa7b7f5c73a434a6cf50165310108e94 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Fri, 19 Jul 2024 01:10:30 +0800
Subject: [PATCH] fix unittest

Former-commit-id: e80006795fe6344ea98b61f9a8db16356498c7cb
---
 tests/data/processors/test_unsupervised.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/data/processors/test_unsupervised.py b/tests/data/processors/test_unsupervised.py
index 976247c7..8713c772 100644
--- a/tests/data/processors/test_unsupervised.py
+++ b/tests/data/processors/test_unsupervised.py
@@ -33,6 +33,8 @@ TRAIN_ARGS = {
     "stage": "sft",
     "do_predict": True,
     "finetuning_type": "full",
+    "eval_dataset": "system_chat",
+    "dataset_dir": "REMOTE:" + DEMO_DATA,
     "template": "llama3",
     "cutoff_len": 8192,
     "overwrite_cache": True,
@@ -45,7 +47,7 @@ TRAIN_ARGS = {
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_unsupervised_data(num_samples: int):
-    train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
+    train_dataset = load_train_dataset(**TRAIN_ARGS)
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)