Former-commit-id: 2f939b708f7f183f87aca67aa076db33a3c8a610
hiyouga 2024-09-05 22:27:48 +08:00
parent 3aa6a3e45b
commit 5585713182
4 changed files with 9 additions and 8 deletions

View File

@@ -38,7 +38,7 @@ jobs:
     env:
       HF_TOKEN: ${{ secrets.HF_TOKEN }}
-      CI_OS: ${{ matrix.os }}
+      OS_NAME: ${{ matrix.os }}
     steps:
     - name: Checkout
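For context, the renamed variable is what the test modules below read from the environment. A minimal sketch of that consumption pattern, assuming the workflow exports OS_NAME from the job matrix (on GitHub Actions, matrix.os typically expands to values such as "ubuntu-latest", "windows-latest", or "macos-latest"):

import os

# OS_NAME comes from the CI workflow; an empty default keeps local runs working.
OS_NAME = os.environ.get("OS_NAME", "")

if OS_NAME.startswith("windows"):
    # Windows-specific handling, e.g. marking a known failure as expected.
    pass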

View File

@@ -158,7 +158,7 @@ def test_qwen_template():
     _check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
 
 
-@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
+@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
 def test_yi_template():
     prompt_str = (
         "<|im_start|>user\nHow are you<|im_end|>\n"

View File

@@ -36,7 +36,6 @@ TRAIN_ARGS = {
     "overwrite_output_dir": True,
     "per_device_train_batch_size": 1,
     "max_steps": 1,
-    "fp16": True,
 }
 
 INFER_ARGS = {
@@ -48,18 +47,20 @@ INFER_ARGS = {
     "export_dir": "llama3_export",
 }
 
+OS_NAME = os.environ.get("OS_NAME", "")
+
 
 @pytest.mark.parametrize(
     "stage,dataset",
     [
         ("pt", "c4_demo"),
         ("sft", "alpaca_en_demo"),
-        ("rm", "dpo_en_demo"),
+        pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
         ("dpo", "dpo_en_demo"),
         ("kto", "kto_en_demo"),
     ],
 )
-def test_train(stage: str, dataset: str):
+def test_run_exp(stage: str, dataset: str):
     output_dir = "train_{}".format(stage)
     run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
     assert os.path.exists(output_dir)
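The pytest.param(...) form used above is the standard way to attach a mark to a single parametrized case instead of the whole test. A small self-contained sketch of the same pattern, with an illustrative condition and values:

import os

import pytest

OS_NAME = os.environ.get("OS_NAME", "")


@pytest.mark.parametrize(
    "value",
    [
        1,
        # Only this case is expected to fail on Windows runners;
        # the other cases keep their normal pass/fail reporting.
        pytest.param(2, marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
        3,
    ],
)
def test_values(value: int):
    assert value > 0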

View File

@@ -49,17 +49,17 @@ INFER_ARGS = {
     "infer_dtype": "float16",
 }
 
-CI_OS = os.environ.get("CI_OS", "")
+OS_NAME = os.environ.get("OS_NAME", "")
 
 
-@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
+@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
 def test_pissa_train():
     model = load_train_model(**TRAIN_ARGS)
     ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
     compare_model(model, ref_model)
 
 
-@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
+@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
 def test_pissa_inference():
     model = load_infer_model(**INFER_ARGS)
     ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
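The practical difference from the previous skipif markers: with a conditional xfail the PiSSA tests still execute on Windows runners, so a genuine connection failure is reported as XFAIL and a successful run as XPASS, instead of the tests being skipped without running. A hedged sketch of both decorators side by side; the helper flaky_network_call is a made-up stand-in:

import os

import pytest

OS_NAME = os.environ.get("OS_NAME", "")


def flaky_network_call() -> bool:
    # Hypothetical stand-in for downloading a checkpoint over the network.
    return True


@pytest.mark.skipif(OS_NAME.startswith("windows"), reason="Not executed at all on Windows.")
def test_with_skipif():
    assert flaky_network_call()


@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Runs on Windows; failures reported as XFAIL.")
def test_with_xfail():
    assert flaky_network_call()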