Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)
fix ci

commit d3eb985bb6 (parent 6a5e3816cf)
Former-commit-id: 7f20e4722ae6ac907b36a3219dcd09d2ff5d071a
.github/workflows/tests.yml (vendored, 6 changed lines)
@@ -30,10 +30,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
           make style && make quality
 
   pytest:
     needs: check_code_quality
@@ -53,7 +53,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Test with pytest
         run: |
           make test
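
Both CI jobs above swap the old metrics,quality extras for the new consolidated dev extra. As a quick local sanity check (a minimal sketch, not part of the commit, assuming `pip install .[torch,dev]` has already been run in the active environment), the two tools the Makefile targets rely on can be verified with the standard library:

# Minimal local sanity check (a sketch, not part of the commit): after
# `pip install .[torch,dev]`, both tools the Makefile targets rely on
# (ruff for `make style && make quality`, pytest for `make test`)
# should be visible as installed distributions.
from importlib.metadata import PackageNotFoundError, version

for package in ("ruff", "pytest"):
    try:
        print(f"{package} {version(package)} is installed")
    except PackageNotFoundError:
        print(f"{package} is missing; reinstall with: pip install .[torch,dev]")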

setup.py (2 changed lines)
@@ -33,7 +33,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "quality": ["ruff"],
+    "dev": ["ruff", "pytest"],
 }
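
For context, keys such as torch, metrics, quality, and dev are ordinary setuptools extras: `pip install .[torch,dev]` resolves each bracketed name against this table. A stripped-down sketch of the mechanism follows; the package name and version are placeholders, not the project's real metadata:

# Stripped-down sketch of the extras mechanism (placeholder metadata, not
# the project's real setup.py): each extras_require key becomes installable
# as `pip install .[<key>]`.
from setuptools import setup

extra_require = {
    "dev": ["ruff", "pytest"],  # the new consolidated extra from this commit
}

setup(
    name="example-package",  # placeholder
    version="0.0.1",  # placeholder
    extras_require=extra_require,
)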

(third changed file; filename not shown in this view)

@@ -23,13 +23,15 @@ def test_attention():
         "fa2": "LlamaFlashAttention2",
     }
     for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({
-            "model_name_or_path": TINY_LLAMA,
-            "template": "llama2",
-            "flash_attn": requested_attention,
-        })
+        model_args, _, finetuning_args, _ = get_infer_args(
+            {
+                "model_name_or_path": TINY_LLAMA,
+                "template": "llama2",
+                "flash_attn": requested_attention,
+            }
+        )
         tokenizer = load_tokenizer(model_args)
         model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
         for module in model.modules():
             if "Attention" in module.__class__.__name__:
                 assert module.__class__.__name__ == llama_attention_classes[requested_attention]
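
This last hunk is purely a formatting change: the get_infer_args call is re-wrapped (likely to satisfy the ruff style the dev extra enforces) and behavior is identical. The test's underlying idea, that requesting an attention backend should materialize the matching attention class on every attention module, can be sketched standalone with plain transformers; the checkpoint path is a placeholder and the "LlamaSdpaAttention" class name is an assumption matching transformers releases contemporary with LlamaFlashAttention2, not a detail taken from this diff:

# Standalone sketch of the same verification idea using plain transformers.
# The checkpoint path is a placeholder, and "LlamaSdpaAttention" matches
# transformers releases of this era; both are assumptions.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/tiny-llama-checkpoint",  # placeholder: any small Llama model
    attn_implementation="sdpa",  # request the SDPA attention backend
)
for module in model.modules():
    if "Attention" in module.__class__.__name__:
        assert module.__class__.__name__ == "LlamaSdpaAttention"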