diff --git a/tests/e2e/test_train.py b/tests/e2e/test_train.py index 3e5d00d8..b997bc9b 100644 --- a/tests/e2e/test_train.py +++ b/tests/e2e/test_train.py @@ -32,7 +32,7 @@ TRAIN_ARGS = { "dataset_dir": "REMOTE:" + DEMO_DATA, "template": "llama3", "cutoff_len": 1, - "overwrite_cache": True, + "overwrite_cache": False, "overwrite_output_dir": True, "per_device_train_batch_size": 1, "max_steps": 1, @@ -55,9 +55,9 @@ OS_NAME = os.environ.get("OS_NAME", "") [ ("pt", "c4_demo"), ("sft", "alpaca_en_demo"), - pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")), ("dpo", "dpo_en_demo"), ("kto", "kto_en_demo"), + pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")), ], ) def test_run_exp(stage: str, dataset: str):