mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
fix ci
Former-commit-id: 2f939b708f7f183f87aca67aa076db33a3c8a610
This commit is contained in:
parent
3aa6a3e45b
commit
5585713182
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
CI_OS: ${{ matrix.os }}
|
||||
OS_NAME: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
@ -158,7 +158,7 @@ def test_qwen_template():
|
||||
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
|
||||
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
|
||||
def test_yi_template():
|
||||
prompt_str = (
|
||||
"<|im_start|>user\nHow are you<|im_end|>\n"
|
||||
|
@ -36,7 +36,6 @@ TRAIN_ARGS = {
|
||||
"overwrite_output_dir": True,
|
||||
"per_device_train_batch_size": 1,
|
||||
"max_steps": 1,
|
||||
"fp16": True,
|
||||
}
|
||||
|
||||
INFER_ARGS = {
|
||||
@ -48,18 +47,20 @@ INFER_ARGS = {
|
||||
"export_dir": "llama3_export",
|
||||
}
|
||||
|
||||
OS_NAME = os.environ.get("OS_NAME", "")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stage,dataset",
|
||||
[
|
||||
("pt", "c4_demo"),
|
||||
("sft", "alpaca_en_demo"),
|
||||
("rm", "dpo_en_demo"),
|
||||
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
|
||||
("dpo", "dpo_en_demo"),
|
||||
("kto", "kto_en_demo"),
|
||||
],
|
||||
)
|
||||
def test_train(stage: str, dataset: str):
|
||||
def test_run_exp(stage: str, dataset: str):
|
||||
output_dir = "train_{}".format(stage)
|
||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||
assert os.path.exists(output_dir)
|
||||
|
@ -49,17 +49,17 @@ INFER_ARGS = {
|
||||
"infer_dtype": "float16",
|
||||
}
|
||||
|
||||
CI_OS = os.environ.get("CI_OS", "")
|
||||
OS_NAME = os.environ.get("OS_NAME", "")
|
||||
|
||||
|
||||
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
|
||||
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
|
||||
def test_pissa_train():
|
||||
model = load_train_model(**TRAIN_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
|
||||
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
|
||||
def test_pissa_inference():
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user