From 7b01c0676c36f2170b43787f9a0acede157c0b85 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Thu, 5 Sep 2024 02:49:22 +0800 Subject: [PATCH] fix ci Former-commit-id: 7899b44b19c3d0a70706d987bb7d2e0e3536014b --- .github/workflows/tests.yml | 3 ++- tests/model/test_pissa.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 139e6154..2457241a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,7 @@ jobs: os: - "ubuntu-latest" - "windows-latest" - - "macos-latest" + - "macos-12" runs-on: ${{ matrix.os }} @@ -38,6 +38,7 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} + CI_OS: ${{ matrix.os }} steps: - name: Checkout diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py index 5e35fcce..26340c3b 100644 --- a/tests/model/test_pissa.py +++ b/tests/model/test_pissa.py @@ -14,6 +14,8 @@ import os +import pytest + from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model @@ -47,6 +49,8 @@ INFER_ARGS = { "infer_dtype": "float16", } +CI_OS = os.environ.get("CI_OS", "") + def test_pissa_train(): model = load_train_model(**TRAIN_ARGS) @@ -54,6 +58,7 @@ def test_pissa_train(): compare_model(model, ref_model) +@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.") def test_pissa_inference(): model = load_infer_model(**INFER_ARGS) ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)