mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 04:10:36 +08:00
[test] add npu test yaml and add ascend a3 docker file (#9547)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
This commit is contained in:
@@ -179,6 +179,7 @@ def _check_plugin(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_base_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
|
||||
base_plugin = get_mm_plugin(name="base")
|
||||
@@ -186,6 +187,7 @@ def test_base_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
|
||||
def test_gemma3_plugin():
|
||||
@@ -208,6 +210,7 @@ def test_gemma3_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
|
||||
def test_internvl_plugin():
|
||||
image_seqlen = 256
|
||||
@@ -226,6 +229,7 @@ def test_internvl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
|
||||
def test_llama4_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
|
||||
@@ -247,6 +251,7 @@ def test_llama4_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llava_plugin():
|
||||
image_seqlen = 576
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
@@ -260,6 +265,7 @@ def test_llava_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llava_next_plugin():
|
||||
image_seqlen = 1176
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||
@@ -273,6 +279,7 @@ def test_llava_next_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llava_next_video_plugin():
|
||||
image_seqlen = 1176
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||
@@ -286,6 +293,7 @@ def test_llava_next_video_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_paligemma_plugin():
|
||||
image_seqlen = 256
|
||||
@@ -305,6 +313,7 @@ def test_paligemma_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
|
||||
def test_pixtral_plugin():
|
||||
image_slice_height, image_slice_width = 2, 2
|
||||
@@ -327,6 +336,7 @@ def test_pixtral_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
|
||||
def test_qwen2_omni_plugin():
|
||||
image_seqlen, audio_seqlen = 4, 2
|
||||
@@ -357,6 +367,7 @@ def test_qwen2_omni_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen2_vl_plugin():
|
||||
image_seqlen = 4
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
@@ -373,6 +384,7 @@ def test_qwen2_vl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
|
||||
def test_qwen3_vl_plugin():
|
||||
frame_seqlen = 1
|
||||
@@ -394,6 +406,7 @@ def test_qwen3_vl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
||||
def test_video_llava_plugin():
|
||||
image_seqlen = 256
|
||||
|
||||
Reference in New Issue
Block a user