From d3eb985bb62382bac38dad6d6d01ac1ea4e57490 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sat, 8 Jun 2024 01:57:36 +0800
Subject: [PATCH] fix ci

Former-commit-id: 7f20e4722ae6ac907b36a3219dcd09d2ff5d071a
---
 .github/workflows/tests.yml                      |  6 +++---
 setup.py                                         |  2 +-
 tests/model/{test_attn.py => test_attention.py} | 14 ++++++++------
 3 files changed, 12 insertions(+), 10 deletions(-)
 rename tests/model/{test_attn.py => test_attention.py} (73%)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a8246986..a66b579b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -30,10 +30,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
-          make style && make quality
+          make style && make quality
 
   pytest:
     needs: check_code_quality
@@ -53,7 +53,7 @@
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Test with pytest
         run: |
           make test
diff --git a/setup.py b/setup.py
index c32be8af..405ac46e 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "quality": ["ruff"],
+    "dev": ["ruff", "pytest"],
 }
 
 
diff --git a/tests/model/test_attn.py b/tests/model/test_attention.py
similarity index 73%
rename from tests/model/test_attn.py
rename to tests/model/test_attention.py
index 12d920ef..6dd46050 100644
--- a/tests/model/test_attn.py
+++ b/tests/model/test_attention.py
@@ -23,13 +23,15 @@ def test_attention():
         "fa2": "LlamaFlashAttention2",
     }
     for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({
-            "model_name_or_path": TINY_LLAMA,
-            "template": "llama2",
-            "flash_attn": requested_attention,
-        })
+        model_args, _, finetuning_args, _ = get_infer_args(
+            {
+                "model_name_or_path": TINY_LLAMA,
+                "template": "llama2",
+                "flash_attn": requested_attention,
+            }
+        )
         tokenizer = load_tokenizer(model_args)
         model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
         for module in model.modules():
             if "Attention" in module.__class__.__name__:
-                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
\ No newline at end of file
+                assert module.__class__.__name__ == llama_attention_classes[requested_attention]