Former-commit-id: b5ffca5a190f3aed8ba8c49bd8cf3239fb787bf5
This commit is contained in:
hiyouga 2024-09-05 22:39:47 +08:00
parent b48b47d519
commit c5ef52a67a

View File

@ -32,7 +32,7 @@ TRAIN_ARGS = {
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 1,
"overwrite_cache": True,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
@ -55,9 +55,9 @@ OS_NAME = os.environ.get("OS_NAME", "")
[
("pt", "c4_demo"),
("sft", "alpaca_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
("dpo", "dpo_en_demo"),
("kto", "kto_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
],
)
def test_run_exp(stage: str, dataset: str):