[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -55,35 +55,30 @@ INFER_ARGS = {
}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "v_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_all_modules():
model = load_train_model(lora_target="all", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_extra_modules():
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
_, extra_modules = check_lora_model(model)
assert extra_modules == {"embed_tokens", "lm_head"}
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_old_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=False, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_train_new_adapters():
model = load_train_model(adapter_name_or_path=TINY_LLAMA_ADAPTER, create_new_adapter=True, **TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True, is_trainable=True)
@@ -92,7 +87,6 @@ def test_lora_train_new_adapters():
)
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
@@ -103,7 +97,6 @@ def test_lora_train_valuehead():
assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])
@pytest.mark.runs_on(["cpu", "npu", "cuda"])
def test_lora_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True).merge_and_unload()