mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
fix ci
Former-commit-id: b6a3fdd056d77dbe692053bc22a8923e24ed2256
This commit is contained in:
parent
cc02fb6180
commit
9f36534b49
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
os:
|
||||
- "ubuntu-latest"
|
||||
- "windows-latest"
|
||||
- "macos-latest"
|
||||
- "macos-12"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
@ -38,6 +38,7 @@ jobs:
|
||||
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
CI_OS: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
@ -14,6 +14,8 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
|
||||
|
||||
|
||||
@ -47,6 +49,8 @@ INFER_ARGS = {
|
||||
"infer_dtype": "float16",
|
||||
}
|
||||
|
||||
CI_OS = os.environ.get("CI_OS", "")
|
||||
|
||||
|
||||
def test_pissa_train():
|
||||
model = load_train_model(**TRAIN_ARGS)
|
||||
@ -54,6 +58,7 @@ def test_pissa_train():
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
|
||||
def test_pissa_inference():
|
||||
model = load_infer_model(**INFER_ARGS)
|
||||
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user