[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -179,7 +179,7 @@ def _check_plugin(
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
base_plugin = get_mm_plugin(name="base")
@@ -187,7 +187,7 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
@@ -210,7 +210,7 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
image_seqlen = 256
@@ -229,7 +229,7 @@ def test_internvl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -251,7 +251,7 @@ def test_llama4_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -265,7 +265,7 @@ def test_llava_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -279,7 +279,7 @@ def test_llava_next_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_video_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
@@ -293,7 +293,7 @@ def test_llava_next_video_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
image_seqlen = 256
@@ -313,7 +313,7 @@ def test_paligemma_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
@@ -336,7 +336,7 @@ def test_pixtral_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
image_seqlen, audio_seqlen = 4, 2
@@ -367,7 +367,7 @@ def test_qwen2_omni_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen2_vl_plugin():
image_seqlen = 4
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -384,7 +384,7 @@ def test_qwen2_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
frame_seqlen = 1
@@ -406,7 +406,7 @@ def test_qwen3_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
image_seqlen = 256