[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Yaowei Zheng
Date: 2025-12-25 02:11:04 +08:00
Committed by: GitHub
Parent: 6a2eafbae3
Commit: a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -18,19 +18,17 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import pytest
from pytest import Config, Item
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.extras.misc import get_current_device, get_device_count, is_env_enabled
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import patch_valuehead_model
try:
CURRENT_DEVICE = get_current_device().type # cpu | cuda | npu
except Exception:
CURRENT_DEVICE = "cpu"
CURRENT_DEVICE = get_current_device().type
def pytest_configure(config: Config):
@@ -66,26 +64,27 @@ def _handle_runs_on(items: list[Item]):
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW", "0"):
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env():
def _get_visible_devices_env() -> Optional[str]:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
if CURRENT_DEVICE == "npu":
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
return None
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE == "cpu":
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
@@ -121,7 +120,7 @@ def pytest_collection_modifyitems(config: Config, items: list[Item]):
@pytest.fixture(autouse=True)
def _manage_distributed_env(request, monkeypatch):
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
@@ -131,8 +130,7 @@ def _manage_distributed_env(request, monkeypatch):
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker:
# Distributed test
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
@@ -142,8 +140,7 @@ def _manage_distributed_env(request, monkeypatch):
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
else:
# Non-distributed test
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
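
For orientation, a minimal hypothetical sketch of how the runs_on marker used throughout these tests could be registered and enforced follows. The real bodies of pytest_configure and _handle_runs_on are not shown in this diff, so the skip logic below is an assumption inferred from the marker usage in the hunks above; only the import and the CURRENT_DEVICE assignment mirror lines actually present in the file.

# Hypothetical sketch (not part of this commit): register the custom marker and
# skip any test whose runs_on list does not include the detected device type.
import pytest
from pytest import Config, Item

from llamafactory.extras.misc import get_current_device

CURRENT_DEVICE = get_current_device().type  # e.g. "cpu", "cuda", "npu", "mps"


def pytest_configure(config: Config):
    config.addinivalue_line("markers", "runs_on(devices): run only on the listed device types")


def _handle_runs_on(items: list[Item]):
    for item in items:
        marker = item.get_closest_marker("runs_on")
        if marker is not None and CURRENT_DEVICE not in marker.args[0]:
            item.add_marker(pytest.mark.skip(reason=f"requires one of {marker.args[0]}"))

Combined with the visibility handling above, a distributed GPU test could then presumably be run with something like RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0,1 pytest tests/, with the autouse fixture narrowing the visible devices according to the require_distributed marker.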

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_feedback_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -51,7 +51,7 @@ def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str
return new_messages
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_pairwise_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -18,7 +18,7 @@ import pytest
from llamafactory.data.processor.processor_utils import infer_seqlen
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"test_input,test_output",
[

View File

@@ -42,7 +42,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_supervised_single_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
@@ -62,7 +62,7 @@ def test_supervised_single_turn(num_samples: int):
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [8])
def test_supervised_multi_turn(num_samples: int):
train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
@@ -76,7 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
assert train_dataset["input_ids"][index] == ref_input_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_train_on_prompt(num_samples: int):
train_dataset = load_dataset_module(
@@ -91,7 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
assert train_dataset["labels"][index] == ref_ids
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [4])
def test_supervised_mask_history(num_samples: int):
train_dataset = load_dataset_module(

View File

@@ -46,7 +46,7 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("num_samples", [16])
def test_unsupervised_data(num_samples: int):
train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]

View File

@@ -29,7 +29,7 @@ from llamafactory.model import load_tokenizer
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_collator():
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA3, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
@@ -73,7 +73,7 @@ def test_base_collator():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}

View File

@@ -20,7 +20,7 @@ from llamafactory.data.parser import DatasetAttr
from llamafactory.hparams import DataArguments
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_alpaca_converter():
dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
data_args = DataArguments()
@@ -41,7 +41,7 @@ def test_alpaca_converter():
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_sharegpt_converter():
dataset_attr = DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset")
data_args = DataArguments()

View File

@@ -38,19 +38,19 @@ TOOLS = [
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_string_formatter():
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
@@ -60,7 +60,7 @@ def test_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -71,7 +71,7 @@ def test_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -90,14 +90,14 @@ def test_default_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
@@ -110,14 +110,14 @@ def test_default_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -128,14 +128,14 @@ def test_glm4_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps(FUNCTION)
@@ -144,7 +144,7 @@ def test_llama3_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
@@ -169,14 +169,14 @@ def test_llama3_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = (
@@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
@@ -199,7 +199,7 @@ def test_mistral_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -220,14 +220,14 @@ def test_mistral_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
@@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
@@ -249,7 +249,7 @@ def test_qwen_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -274,14 +274,14 @@ def test_qwen_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (

View File

@@ -40,21 +40,21 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_train_only():
dataset_module = load_dataset_module(**TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None
assert dataset_module.get("eval_dataset") is None
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_val_size():
dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None
assert dataset_module.get("eval_dataset") is not None
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_load_eval_data():
dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
assert dataset_module.get("train_dataset") is not None

View File

@@ -179,7 +179,7 @@ def _check_plugin(
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
base_plugin = get_mm_plugin(name="base")
@@ -187,7 +187,7 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
@@ -210,7 +210,7 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
image_seqlen = 256
@@ -229,7 +229,7 @@ def test_internvl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -251,7 +251,7 @@ def test_llama4_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -265,7 +265,7 @@ def test_llava_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -279,7 +279,7 @@ def test_llava_next_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_video_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
@@ -293,7 +293,7 @@ def test_llava_next_video_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
image_seqlen = 256
@@ -313,7 +313,7 @@ def test_paligemma_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
@@ -336,7 +336,7 @@ def test_pixtral_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
image_seqlen, audio_seqlen = 4, 2
@@ -367,7 +367,7 @@ def test_qwen2_omni_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen2_vl_plugin():
image_seqlen = 4
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -384,7 +384,7 @@ def test_qwen2_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
frame_seqlen = 1
@@ -406,7 +406,7 @@ def test_qwen3_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
image_seqlen = 256

View File

@@ -89,7 +89,7 @@ def _check_template(
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -105,7 +105,7 @@ def test_encode_oneturn(use_fast: bool):
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -127,7 +127,7 @@ def test_encode_multiturn(use_fast: bool):
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -154,7 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
@@ -184,7 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
@@ -195,7 +195,7 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_ollama_modelfile():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
@@ -213,14 +213,14 @@ def test_ollama_modelfile():
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_get_stop_token_ids():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
@@ -234,7 +234,7 @@ def test_gemma_template(use_fast: bool):
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
@@ -248,7 +248,7 @@ def test_gemma2_template(use_fast: bool):
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
@@ -262,7 +262,7 @@ def test_llama3_template(use_fast: bool):
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
@@ -284,7 +284,7 @@ def test_llama4_template(use_fast: bool):
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
],
)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_phi4_template(use_fast: bool):
prompt_str = (
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
@@ -296,7 +296,7 @@ def test_phi4_template(use_fast: bool):
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
@@ -311,7 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
@@ -331,7 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_parse_llama3_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
template = parse_template(tokenizer)
@@ -345,7 +345,7 @@ def test_parse_llama3_template():
assert template.default_system == ""
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
@@ -358,7 +358,7 @@ def test_parse_qwen_template():
assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)

View File

@@ -37,13 +37,13 @@ MESSAGES = [
EXPECTED_RESPONSE = "_rho"
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_chat():
chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_stream_chat():
chat_model = ChatModel(INFER_ARGS)
response = ""

View File

@@ -39,7 +39,7 @@ MESSAGES = [
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
r"""Test the SGLang engine's basic chat functionality."""
@@ -49,7 +49,7 @@ def test_chat():
print(response.response_text)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality."""

View File

@@ -49,7 +49,7 @@ INFER_ARGS = {
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"stage,dataset",
[
@@ -66,7 +66,7 @@ def test_run_exp(stage: str, dataset: str):
assert os.path.exists(output_dir)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_export():
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})

View File

@@ -17,7 +17,7 @@ import pytest
from llamafactory.eval.template import get_eval_template
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_en():
support_set = [
{
@@ -56,7 +56,7 @@ def test_eval_template_en():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_zh():
support_set = [
{

View File

@@ -25,7 +25,6 @@ TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool):
if special_tokens:

View File

@@ -39,7 +39,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
def test_attention():
attention_available = ["disabled"]

View File

@@ -39,7 +39,6 @@ TRAIN_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
@@ -47,14 +46,12 @@ def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
assert getattr(module, "gradient_checkpointing") != disable_gradient_checkpointing
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -62,7 +59,6 @@ def test_upcast_layernorm():
assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())

View File

@@ -24,7 +24,6 @@ from llamafactory.model.model_utils.misc import find_expanded_modules
HF_TOKEN = os.getenv("HF_TOKEN")
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules():
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

View File

@@ -18,7 +18,6 @@ import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize(
"attention_mask,golden_seq_lens",
[

View File

@@ -23,7 +23,6 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower", (False, True))
@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
@pytest.mark.parametrize("freeze_language_model", (False, True))
@@ -49,7 +48,6 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
assert param.requires_grad != freeze_language_model
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
@@ -82,7 +80,6 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
assert (merger_param_name in trainable_params) is False
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_visual_model_save_load():
# check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")

View File

@@ -30,14 +30,12 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_base():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3)
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
model = load_infer_model(add_valuehead=True, **INFER_ARGS)

View File

@@ -14,7 +14,6 @@
import os
import pytest
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_all_modules():
model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -56,7 +54,6 @@ def test_freeze_train_all_modules():
assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_train_extra_modules():
model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
for name, param in model.named_parameters():
@@ -68,7 +65,6 @@ def test_freeze_train_extra_modules():
assert param.dtype == torch.float16
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_freeze_inference():
model = load_infer_model(**INFER_ARGS)
for param in model.parameters():

View File

@@ -14,7 +14,6 @@
import os
import pytest
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
@@ -44,7 +43,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_train():
model = load_train_model(**TRAIN_ARGS)
for param in model.parameters():
@@ -52,7 +50,6 @@ def test_full_train():
assert param.dtype == torch.float32
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_full_inference():
model = load_infer_model(**INFER_ARGS)
for param in model.parameters():

View File

@@ -55,35 +55,30 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "v_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_all_modules():
model = load_train_model(lora_target="all", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_extra_modules():
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
_, extra_modules = check_lora_model(model)
assert extra_modules == {"embed_tokens", "lm_head"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
@@ -92,7 +87,6 @@ def test_lora_train_new_adapters():
)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
@@ -103,7 +97,6 @@ def test_lora_train_valuehead():
assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()

View File

@@ -49,7 +49,6 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
@@ -57,7 +56,6 @@ def test_pissa_train():
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)

View File

@@ -59,7 +59,6 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
return {k: v[:, :1] for k, v in batch.items()} # truncate input length
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.parametrize("disable_shuffling", [False, True])
def test_shuffle(disable_shuffling: bool):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(

View File

@@ -1,18 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
runs_on = pytest.mark.runs_on
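
With this alias module removed, tests presumably no longer import a shared runs_on name and instead apply the marker through pytest.mark directly, as the updated hunks above already do. A minimal sketch, assuming the marker is registered in conftest.py:

import pytest


@pytest.mark.runs_on(["cpu", "mps"])  # resolved via pytest's marker registry, not the deleted alias
def test_example():
    assert True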

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.104
0.9.4.105