mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 17:40:34 +08:00
[v1] add sft (#9752)
This commit is contained in:
@@ -19,8 +19,8 @@ from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
|
||||
|
||||
def test_normal_batching():
|
||||
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args=data_args)
|
||||
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
training_args = TrainingArguments(
|
||||
|
||||
@@ -111,8 +111,8 @@ def test_chatml_parse():
|
||||
def test_chatml_rendering_remote(num_samples: int):
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||
renderer = Renderer(template="chatml", processor=tokenizer)
|
||||
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args)
|
||||
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
for index in range(num_samples):
|
||||
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
|
||||
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
|
||||
@@ -167,8 +167,8 @@ def test_qwen3_nothink_parse():
|
||||
def test_qwen3_nothink_rendering_remote(num_samples: int):
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
|
||||
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
|
||||
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
|
||||
data_engine = DataEngine(data_args)
|
||||
data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
|
||||
data_engine = DataEngine(data_args.train_dataset)
|
||||
for index in range(num_samples):
|
||||
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
|
||||
prefix_text = (
|
||||
@@ -213,7 +213,7 @@ def test_process_dpo_samples():
|
||||
model_inputs = renderer.process_samples(samples)
|
||||
assert len(model_inputs) == 1
|
||||
assert model_inputs[0]["input_ids"] == hf_inputs * 2
|
||||
assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs)
|
||||
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
|
||||
assert model_inputs[0]["extra_info"] == "test"
|
||||
assert model_inputs[0]["_dataset_name"] == "default"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user