mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
[test] add npu test yaml and add ascend a3 docker file (#9547)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
This commit is contained in:
@@ -42,6 +42,7 @@ TRAIN_ARGS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_supervised_single_turn(num_samples: int):
|
||||
train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
|
||||
@@ -61,6 +62,7 @@ def test_supervised_single_turn(num_samples: int):
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.parametrize("num_samples", [8])
|
||||
def test_supervised_multi_turn(num_samples: int):
|
||||
train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
|
||||
@@ -74,6 +76,7 @@ def test_supervised_multi_turn(num_samples: int):
|
||||
assert train_dataset["input_ids"][index] == ref_input_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.parametrize("num_samples", [4])
|
||||
def test_supervised_train_on_prompt(num_samples: int):
|
||||
train_dataset = load_dataset_module(
|
||||
@@ -88,6 +91,7 @@ def test_supervised_train_on_prompt(num_samples: int):
|
||||
assert train_dataset["labels"][index] == ref_ids
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.parametrize("num_samples", [4])
|
||||
def test_supervised_mask_history(num_samples: int):
|
||||
train_dataset = load_dataset_module(
|
||||
|
||||
Reference in New Issue
Block a user