mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 09:30:34 +08:00
[v1] add sft (#9752)
This commit is contained in:
@@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_alpaca_converter(num_samples: int):
|
||||
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
|
||||
data_engine = DataEngine(data_args)
|
||||
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
|
||||
indexes = random.choices(range(len(data_engine)), k=num_samples)
|
||||
for index in indexes:
|
||||
@@ -73,8 +73,8 @@ def test_sharegpt_converter():
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_pair_converter(num_samples: int):
|
||||
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
|
||||
data_engine = DataEngine(data_args)
|
||||
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
|
||||
indexes = random.choices(range(len(data_engine)), k=num_samples)
|
||||
for index in indexes:
|
||||
|
||||
@@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
|
||||
|
||||
|
||||
def test_init_on_meta():
|
||||
_, model_args, *_ = get_args(
|
||||
model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen3",
|
||||
init_config={"name": "init_on_meta"},
|
||||
@@ -30,7 +30,7 @@ def test_init_on_meta():
|
||||
|
||||
|
||||
def test_init_on_rank0():
|
||||
_, model_args, *_ = get_args(
|
||||
model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen3",
|
||||
init_config={"name": "init_on_rank0"},
|
||||
@@ -44,7 +44,7 @@ def test_init_on_rank0():
|
||||
|
||||
|
||||
def test_init_on_default():
|
||||
_, model_args, *_ = get_args(
|
||||
model_args, *_ = get_args(
|
||||
dict(
|
||||
model="llamafactory/tiny-random-qwen3",
|
||||
init_config={"name": "init_on_default"},
|
||||
|
||||
@@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
|
||||
reload_kernels()
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||
# NOTE: use a special model to avoid contamination by other tests
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
|
||||
@@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
|
||||
reload_kernels()
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||
# NOTE: use a special model to avoid contamination by other tests
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
Reference in New Issue
Block a user