Former-commit-id: eb5aa9adce7c01d453d45d2c901e530584e46eb6
This commit is contained in:
hiyouga 2024-09-05 22:39:47 +08:00
parent 5585713182
commit 52d3c42265

View File

@ -32,7 +32,7 @@ TRAIN_ARGS = {
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 1,
"overwrite_cache": True,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
@ -55,9 +55,9 @@ OS_NAME = os.environ.get("OS_NAME", "")
[
("pt", "c4_demo"),
("sft", "alpaca_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
("dpo", "dpo_en_demo"),
("kto", "kto_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
],
)
def test_run_exp(stage: str, dataset: str):