diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 9c2f527c..fd050f91 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -43,7 +43,7 @@ class Role(str, Enum):
 
 class DatasetModule(TypedDict):
     train_dataset: Optional[Union["Dataset", "IterableDataset"]]
-    eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
 
 
 def merge_dataset(
@@ -54,11 +54,13 @@ def merge_dataset(
     """
     if len(all_datasets) == 1:
         return all_datasets[0]
+
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
             logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
 
         return concatenate_datasets(all_datasets)
+
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
             logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
@@ -69,24 +71,75 @@ def merge_dataset(
             seed=seed,
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
+
     else:
         raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
 
 
 def split_dataset(
-    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
+    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
+    data_args: "DataArguments",
+    seed: int,
 ) -> "DatasetDict":
     r"""
     Splits the dataset and returns a dataset dict containing train set and validation set.
 
     Supports both map dataset and iterable dataset.
     """
-    if data_args.streaming:
-        dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
-        val_set = dataset.take(int(data_args.val_size))
-        train_set = dataset.skip(int(data_args.val_size))
-        return DatasetDict({"train": train_set, "validation": val_set})
-    else:
-        val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
-        dataset = dataset.train_test_split(test_size=val_size, seed=seed)
-        return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+    if eval_dataset is not None and data_args.val_size > 1e-6:
+        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
+
+    dataset_dict = {}
+    if dataset is not None:
+        if data_args.streaming:
+            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+        if data_args.val_size > 1e-6:
+            if data_args.streaming:
+                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
+                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
+            else:
+                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+                dataset = dataset.train_test_split(test_size=val_size, seed=seed)
+                dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
+        else:
+            dataset_dict["train"] = dataset
+
+    if eval_dataset is not None:
+        if isinstance(eval_dataset, dict):
+            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
+        else:
+            if data_args.streaming:
+                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
+
+            dataset_dict["validation"] = eval_dataset
+
+    return DatasetDict(dataset_dict)
+
+
+def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
+    r"""
+    Converts dataset or dataset dict to dataset module.
+    """
+    dataset_module: "DatasetModule" = {}
+    if isinstance(dataset, DatasetDict):  # dataset dict
+        if "train" in dataset:
+            dataset_module["train_dataset"] = dataset["train"]
+
+        if "validation" in dataset:
+            dataset_module["eval_dataset"] = dataset["validation"]
+        else:
+            eval_dataset = {}
+            for key in dataset.keys():
+                if key.startswith("validation_"):
+                    eval_dataset[key[len("validation_") :]] = dataset[key]
+
+            if len(eval_dataset):
+                dataset_module["eval_dataset"] = eval_dataset
+
+    else:  # single dataset
+        dataset_module["train_dataset"] = dataset
+
+    return dataset_module
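
A minimal usage sketch of the two new helpers (not part of the patch; the `DataArguments` stub below carries only the fields `split_dataset` reads, and the dataset names are illustrative):

    from types import SimpleNamespace

    from datasets import Dataset

    from llamafactory.data.data_utils import get_dataset_module, split_dataset

    # Stub standing in for DataArguments; field names follow the code above.
    data_args = SimpleNamespace(streaming=False, val_size=0.0, buffer_size=16384)

    train = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
    evals = {
        "alpaca": Dataset.from_dict({"text": ["e"]}),  # illustrative names
        "gsm8k": Dataset.from_dict({"text": ["f"]}),
    }

    # Named eval datasets land under "validation_<name>" keys...
    dataset_dict = split_dataset(train, evals, data_args, seed=42)
    assert sorted(dataset_dict) == ["train", "validation_alpaca", "validation_gsm8k"]

    # ...and get_dataset_module() folds them back into a dict-valued eval_dataset.
    module = get_dataset_module(dataset_dict)
    assert sorted(module["eval_dataset"]) == ["alpaca", "gsm8k"]
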
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 7782dd5b..f6cb955f 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -17,13 +17,13 @@ import sys
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
-from datasets import DatasetDict, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
 from .converter import align_dataset
-from .data_utils import merge_dataset, split_dataset
+from .data_utils import get_dataset_module, merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .processor import (
     FeedbackDatasetProcessor,
@@ -292,23 +292,12 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
-            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
-
-            dataset_module: Dict[str, "Dataset"] = {}
-            if isinstance(tokenized_data, DatasetDict):
-                if "train" in tokenized_data:
-                    dataset_module["train_dataset"] = tokenized_data["train"]
-
-                if "validation" in tokenized_data:
-                    dataset_module["eval_dataset"] = tokenized_data["validation"]
-
-            else:  # single dataset
-                dataset_module["train_dataset"] = tokenized_data
-
+            tokenized_data = load_from_disk(data_args.tokenized_path)
+            dataset_module = get_dataset_module(tokenized_data)
             if data_args.streaming:
-                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+                dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()
 
+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
             return dataset_module
 
         if data_args.streaming:
@@ -335,27 +324,7 @@ def get_dataset(
             eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
         )
 
-        if data_args.val_size > 1e-6:
-            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
-        else:
-            dataset_dict = {}
-            if dataset is not None:
-                if data_args.streaming:
-                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                dataset_dict["train"] = dataset
-
-            if eval_dataset is not None:
-                if isinstance(eval_dataset, dict):
-                    dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
-                else:
-                    if data_args.streaming:
-                        eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                    dataset_dict["validation"] = eval_dataset
-
-            dataset_dict = DatasetDict(dataset_dict)
-
+        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
         if data_args.tokenized_path is not None:  # save tokenized dataset to disk and exit
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
@@ -364,19 +333,4 @@ def get_dataset(
 
             sys.exit(0)
 
-        dataset_module = {}
-        if "train" in dataset_dict:
-            dataset_module["train_dataset"] = dataset_dict["train"]
-
-        if "validation" in dataset_dict:
-            dataset_module["eval_dataset"] = dataset_dict["validation"]
-        else:
-            eval_dataset = {}
-            for key in dataset_dict.keys():
-                if key.startswith("validation_"):
-                    eval_dataset[key[len("validation_") :]] = dataset_dict[key]
-
-            if len(eval_dataset):
-                dataset_module["eval_dataset"] = eval_dataset
-
-        return dataset_module
+        return get_dataset_module(dataset_dict)
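
The simplified `tokenized_path` branch above relies on a save/reload round-trip that can be sketched as follows (the path is illustrative; note that under `streaming`, the new code converts only the train split to an iterable dataset):

    from datasets import Dataset, DatasetDict, load_from_disk

    from llamafactory.data.data_utils import get_dataset_module

    # First run: get_dataset() saves the split DatasetDict to disk and exits.
    dataset_dict = DatasetDict({
        "train": Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}),
        "validation": Dataset.from_dict({"input_ids": [[5, 6]]}),
    })
    dataset_dict.save_to_disk("/tmp/tokenized_demo")

    # Second run: the reloaded DatasetDict converts straight into a DatasetModule.
    module = get_dataset_module(load_from_disk("/tmp/tokenized_demo"))
    assert module["eval_dataset"].num_rows == 1
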
diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py
index 9404c249..f8ba3510 100644
--- a/src/llamafactory/train/test_utils.py
+++ b/src/llamafactory/train/test_utils.py
@@ -26,10 +26,11 @@ from ..model import load_model, load_tokenizer
 
 if TYPE_CHECKING:
-    from datasets import Dataset
     from peft import LoraModel
     from transformers import PreTrainedModel
 
+    from ..data.data_utils import DatasetModule
+
 
 def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
     state_dict_a = model_a.state_dict()
     state_dict_b = model_b.state_dict()
@@ -101,12 +102,12 @@ def load_reference_model(
     return model
 
 
-def load_train_dataset(**kwargs) -> "Dataset":
+def load_dataset_module(**kwargs) -> "DatasetModule":
     model_args, data_args, training_args, _, _ = get_train_args(kwargs)
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
     dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
-    return dataset_module["train_dataset"]
+    return dataset_module
 
 
 def patch_valuehead_model() -> None:
diff --git a/tests/data/processor/test_feedback.py b/tests/data/processor/test_feedback.py
index a70c6e1d..27c3676c 100644
--- a/tests/data/processor/test_feedback.py
+++ b/tests/data/processor/test_feedback.py
@@ -20,7 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.train.test_utils import load_train_dataset
+from llamafactory.train.test_utils import load_dataset_module
 
 
 DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
@@ -36,7 +36,6 @@ TRAIN_ARGS = {
     "dataset_dir": "REMOTE:" + DEMO_DATA,
     "template": "llama3",
     "cutoff_len": 8192,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
@@ -45,7 +44,7 @@
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_feedback_data(num_samples: int):
-    train_dataset = load_train_dataset(**TRAIN_ARGS)
+    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="kto_en_demo", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
diff --git a/tests/data/processor/test_pairwise.py b/tests/data/processor/test_pairwise.py
index 7602d070..3faac9a7 100644
--- a/tests/data/processor/test_pairwise.py
+++ b/tests/data/processor/test_pairwise.py
@@ -21,7 +21,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.train.test_utils import load_train_dataset
+from llamafactory.train.test_utils import load_dataset_module
 
 
 DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
@@ -37,7 +36,6 @@ TRAIN_ARGS = {
     "dataset_dir": "REMOTE:" + DEMO_DATA,
     "template": "llama3",
     "cutoff_len": 8192,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
@@ -55,7 +54,7 @@ def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_pairwise_data(num_samples: int):
-    train_dataset = load_train_dataset(**TRAIN_ARGS)
+    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="dpo_en_demo", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
diff --git a/tests/data/processor/test_supervised.py b/tests/data/processor/test_supervised.py
index 2a988e84..e2171721 100644
--- a/tests/data/processor/test_supervised.py
+++ b/tests/data/processor/test_supervised.py
@@ -20,7 +20,7 @@ from datasets import load_dataset
 from transformers import AutoTokenizer
 
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.train.test_utils import load_train_dataset
+from llamafactory.train.test_utils import load_dataset_module
 
 
 DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
@@ -36,7 +36,6 @@ TRAIN_ARGS = {
     "finetuning_type": "full",
     "template": "llama3",
     "cutoff_len": 8192,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
@@ -45,7 +44,7 @@
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_supervised_single_turn(num_samples: int):
-    train_dataset = load_train_dataset(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)
+    train_dataset = load_dataset_module(dataset_dir="ONLINE", dataset=TINY_DATA, **TRAIN_ARGS)["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(TINY_DATA, split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
@@ -64,7 +63,9 @@ def test_supervised_single_turn(num_samples: int):
 
 @pytest.mark.parametrize("num_samples", [8])
 def test_supervised_multi_turn(num_samples: int):
-    train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
+    train_dataset = load_dataset_module(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)[
+        "train_dataset"
+    ]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
@@ -75,9 +76,9 @@
 
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_train_on_prompt(num_samples: int):
-    train_dataset = load_train_dataset(
+    train_dataset = load_dataset_module(
         dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", train_on_prompt=True, **TRAIN_ARGS
-    )
+    )["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
@@ -89,9 +90,9 @@
 
 @pytest.mark.parametrize("num_samples", [4])
 def test_supervised_mask_history(num_samples: int):
-    train_dataset = load_train_dataset(
+    train_dataset = load_dataset_module(
         dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", mask_history=True, **TRAIN_ARGS
-    )
+    )["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
diff --git a/tests/data/processor/test_unsupervised.py b/tests/data/processor/test_unsupervised.py
index c3f3159f..4b0a97d3 100644
--- a/tests/data/processor/test_unsupervised.py
+++ b/tests/data/processor/test_unsupervised.py
@@ -19,7 +19,7 @@ import pytest
 from datasets import load_dataset
 from transformers import AutoTokenizer
 
-from llamafactory.train.test_utils import load_train_dataset
+from llamafactory.train.test_utils import load_dataset_module
 
 
 DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
@@ -39,7 +38,6 @@ TRAIN_ARGS = {
     "dataset_dir": "REMOTE:" + DEMO_DATA,
     "template": "llama3",
     "cutoff_len": 8192,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
@@ -48,7 +47,7 @@
 
 @pytest.mark.parametrize("num_samples", [16])
 def test_unsupervised_data(num_samples: int):
-    train_dataset = load_train_dataset(**TRAIN_ARGS)
+    train_dataset = load_dataset_module(**TRAIN_ARGS)["train_dataset"]
     ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
     original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
     indexes = random.choices(range(len(original_data)), k=num_samples)
diff --git a/tests/data/test_converter.py b/tests/data/test_converter.py
index 0308d3ee..6997f75f 100644
--- a/tests/data/test_converter.py
+++ b/tests/data/test_converter.py
@@ -1,3 +1,17 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from llamafactory.data import Role
 from llamafactory.data.converter import get_dataset_converter
 from llamafactory.data.parser import DatasetAttr
diff --git a/tests/data/test_loader.py b/tests/data/test_loader.py
new file mode 100644
index 00000000..fc2d2a91
--- /dev/null
+++ b/tests/data/test_loader.py
@@ -0,0 +1,56 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from llamafactory.train.test_utils import load_dataset_module
+
+
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
+
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
+
+TRAIN_ARGS = {
+    "model_name_or_path": TINY_LLAMA,
+    "stage": "sft",
+    "do_train": True,
+    "finetuning_type": "full",
+    "template": "llama3",
+    "dataset": TINY_DATA,
+    "dataset_dir": "ONLINE",
+    "cutoff_len": 8192,
+    "output_dir": "dummy_dir",
+    "overwrite_output_dir": True,
+    "fp16": True,
+}
+
+
+def test_load_train_only():
+    dataset_module = load_dataset_module(**TRAIN_ARGS)
+    assert dataset_module.get("train_dataset") is not None
+    assert dataset_module.get("eval_dataset") is None
+
+
+def test_load_val_size():
+    dataset_module = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
+    assert dataset_module.get("train_dataset") is not None
+    assert dataset_module.get("eval_dataset") is not None
+
+
+def test_load_eval_data():
+    dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
+    assert dataset_module.get("train_dataset") is not None
+    assert dataset_module.get("eval_dataset") is not None
diff --git a/tests/e2e/test_train.py b/tests/e2e/test_train.py
index 46d7813c..f16b3522 100644
--- a/tests/e2e/test_train.py
+++ b/tests/e2e/test_train.py
@@ -32,7 +32,6 @@ TRAIN_ARGS = {
     "dataset_dir": "REMOTE:" + DEMO_DATA,
     "template": "llama3",
     "cutoff_len": 1,
-    "overwrite_cache": False,
     "overwrite_output_dir": True,
     "per_device_train_batch_size": 1,
     "max_steps": 1,
diff --git a/tests/model/model_utils/test_checkpointing.py b/tests/model/model_utils/test_checkpointing.py
index ef38d0d5..f0246016 100644
--- a/tests/model/model_utils/test_checkpointing.py
+++ b/tests/model/model_utils/test_checkpointing.py
@@ -33,7 +33,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
diff --git a/tests/model/test_freeze.py b/tests/model/test_freeze.py
index cd9fc61c..97200852 100644
--- a/tests/model/test_freeze.py
+++ b/tests/model/test_freeze.py
@@ -30,7 +30,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
diff --git a/tests/model/test_full.py b/tests/model/test_full.py
index 3bd9c9e8..8aff2223 100644
--- a/tests/model/test_full.py
+++ b/tests/model/test_full.py
@@ -30,7 +30,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 2a4177ce..1cda7bb7 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -42,7 +42,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py
index d1e4114c..875a3bf4 100644
--- a/tests/model/test_pissa.py
+++ b/tests/model/test_pissa.py
@@ -34,7 +34,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": True,
     "output_dir": "dummy_dir",
     "overwrite_output_dir": True,
     "fp16": True,
diff --git a/tests/train/test_sft_trainer.py b/tests/train/test_sft_trainer.py
index bb67a31e..1f84071e 100644
--- a/tests/train/test_sft_trainer.py
+++ b/tests/train/test_sft_trainer.py
@@ -38,7 +38,6 @@ TRAIN_ARGS = {
     "dataset_dir": "ONLINE",
     "template": "llama3",
     "cutoff_len": 1024,
-    "overwrite_cache": False,
     "overwrite_output_dir": True,
     "per_device_train_batch_size": 1,
     "max_steps": 1,
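
Downstream, the dict form of `eval_dataset` plugs into `transformers.Trainer`, which accepts `Dict[str, Dataset]` and reports each entry's metrics under its own prefix (e.g. `eval_<name>_loss`). A hedged sketch reusing the constants from the new `tests/data/test_loader.py` above; `model` and `training_args` are placeholders, not defined here:

    from transformers import Trainer

    dataset_module = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
    trainer = Trainer(
        model=model,  # placeholder: any PreTrainedModel
        args=training_args,  # placeholder: a transformers.TrainingArguments
        train_dataset=dataset_module["train_dataset"],
        eval_dataset=dataset_module["eval_dataset"],  # Dataset or Dict[str, Dataset]
    )
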