Former-commit-id: 2f939b708f7f183f87aca67aa076db33a3c8a610
This commit is contained in:
hiyouga 2024-09-05 22:27:48 +08:00
parent 3aa6a3e45b
commit 5585713182
4 changed files with 9 additions and 8 deletions

View File

@ -38,7 +38,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
CI_OS: ${{ matrix.os }}
OS_NAME: ${{ matrix.os }}
steps:
- name: Checkout

View File

@ -158,7 +158,7 @@ def test_qwen_template():
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
def test_yi_template():
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"

View File

@ -36,7 +36,6 @@ TRAIN_ARGS = {
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
"fp16": True,
}
INFER_ARGS = {
@ -48,18 +47,20 @@ INFER_ARGS = {
"export_dir": "llama3_export",
}
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.parametrize(
"stage,dataset",
[
("pt", "c4_demo"),
("sft", "alpaca_en_demo"),
("rm", "dpo_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
("dpo", "dpo_en_demo"),
("kto", "kto_en_demo"),
],
)
def test_train(stage: str, dataset: str):
def test_run_exp(stage: str, dataset: str):
output_dir = "train_{}".format(stage)
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)

View File

@ -49,17 +49,17 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
CI_OS = os.environ.get("CI_OS", "")
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)