[v1] add sft (#9752)

This commit is contained in:
Yaowei Zheng
2026-01-12 03:15:01 +08:00
committed by GitHub
parent 4d3621e3d3
commit 958b9c3468
29 changed files with 439 additions and 305 deletions

View File

@@ -34,7 +34,7 @@ def test_get_args_from_yaml(tmp_path: Path):
quant_config: null
### data
dataset: llamafactory/v1-sft-demo
train_dataset: llamafactory/v1-sft-demo
### training
output_dir: outputs/test_run
@@ -56,8 +56,8 @@ def test_get_args_from_yaml(tmp_path: Path):
test_argv = ["test_args_parser.py", str(config_file)]
with patch.object(sys, "argv", test_argv):
data_args, model_args, training_args, sample_args = get_args()
assert data_args.dataset == "llamafactory/v1-sft-demo"
model_args, data_args, training_args, sample_args = get_args()
assert data_args.train_dataset == "llamafactory/v1-sft-demo"
assert model_args.model == "llamafactory/tiny-random-qwen3"
assert model_args.kernel_config.name == "auto"
assert model_args.kernel_config.get("include_kernels") == "auto"

View File

@@ -23,8 +23,8 @@ from llamafactory.v1.core.data_engine import DataEngine
@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:

View File

@@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator
def test_normal_batching():
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args=data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
model_engine = ModelEngine(model_args=model_args)
training_args = TrainingArguments(

View File

@@ -111,8 +111,8 @@ def test_chatml_parse():
def test_chatml_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
@@ -167,8 +167,8 @@ def test_qwen3_nothink_parse():
def test_qwen3_nothink_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
prefix_text = (
@@ -213,7 +213,7 @@ def test_process_dpo_samples():
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs * 2
assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"

View File

@@ -24,8 +24,8 @@ from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:
@@ -73,8 +73,8 @@ def test_sharegpt_converter():
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
data_engine = DataEngine(data_args)
data_args = DataArguments(train_dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
data_engine = DataEngine(data_args.train_dataset)
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:

View File

@@ -19,7 +19,7 @@ from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_meta"},
@@ -30,7 +30,7 @@ def test_init_on_meta():
def test_init_on_rank0():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_rank0"},
@@ -44,7 +44,7 @@ def test_init_on_rank0():
def test_init_on_default():
_, model_args, *_ = get_args(
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_default"},

View File

@@ -43,7 +43,8 @@ def test_apply_kernel(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
# NOTE: use a special model to avoid contamination by other tests
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm")
@@ -62,7 +63,8 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock):
reload_kernels()
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3")
# NOTE: use a special model to avoid contamination by other tests
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward