From 74f0d02eb835c148383dcde504395fd0399f4cf2 Mon Sep 17 00:00:00 2001 From: codingma Date: Fri, 5 Jul 2024 15:52:10 +0800 Subject: [PATCH 1/8] 1. add custom eval dataset support 2. merge load dataset and split dataset function Former-commit-id: 76f3bbcfc0e11aa41f8f5cbebc60b77b987f7901 --- data/README.md | 3 +- data/README_zh.md | 1 + data/dataset_info.json | 12 ++++- scripts/cal_lr.py | 4 +- scripts/cal_ppl.py | 4 +- scripts/length_cdf.py | 6 +-- src/llamafactory/data/loader.py | 65 ++++++++++++++++++++------ src/llamafactory/data/parser.py | 9 ++-- src/llamafactory/hparams/data_args.py | 9 ++++ src/llamafactory/train/dpo/workflow.py | 4 +- src/llamafactory/train/kto/workflow.py | 4 +- src/llamafactory/train/ppo/workflow.py | 4 +- src/llamafactory/train/pt/workflow.py | 4 +- src/llamafactory/train/rm/workflow.py | 6 +-- src/llamafactory/train/sft/workflow.py | 8 ++-- tests/data/test_supervised.py | 4 +- 16 files changed, 104 insertions(+), 43 deletions(-) diff --git a/data/README.md b/data/README.md index 5ceae666..0f14bef8 100644 --- a/data/README.md +++ b/data/README.md @@ -12,7 +12,8 @@ Currently we support datasets in **alpaca** and **sharegpt** format. "ranking": "whether the dataset is a preference dataset or not. (default: False)", "subset": "the name of the subset. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", - "num_samples": "the number of samples in the dataset used for training. (optional, default: None)", + "num_samples": "the number of samples in the dataset used for training. (optional, default: None)", + "split": "which dataset split to use for training and evaluation (optional, default: train)", "columns (optional)": { "prompt": "the column name in the dataset containing the prompts. (default: instruction)", "query": "the column name in the dataset containing the queries. 
(default: input)", diff --git a/data/README_zh.md b/data/README_zh.md index 1795f352..7bf4fdba 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -13,6 +13,7 @@ "subset": "数据集子集的名称(可选,默认:None)", "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)", "num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)", + "split": "数据集中的要使用的训练测试集切分(可选,默认:train)", "columns(可选)": { "prompt": "数据集代表提示词的表头名称(默认:instruction)", "query": "数据集代表请求的表头名称(默认:input)", diff --git a/data/dataset_info.json b/data/dataset_info.json index f8ffd407..e4b5a384 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -172,9 +172,19 @@ "deepctrl": { "ms_hub_url": "deepctrl/deepctrl-sft-data" }, - "adgen": { + "adgen_train": { "hf_hub_url": "HasturOfficial/adgen", "ms_hub_url": "AI-ModelScope/adgen", + "split": "train", + "columns": { + "prompt": "content", + "response": "summary" + } + }, + "adgen_val": { + "hf_hub_url": "HasturOfficial/adgen", + "ms_hub_url": "AI-ModelScope/adgen", + "split": "validation", "columns": { "prompt": "content", "response": "summary" diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py index a103e082..a38f34e1 100644 --- a/scripts/cal_lr.py +++ b/scripts/cal_lr.py @@ -65,7 +65,7 @@ def calculate_lr( ) tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module) if stage == "pt": data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) elif stage == "sft": @@ -73,7 +73,7 @@ def calculate_lr( else: raise NotImplementedError("Stage does not supported: {}.".format(stage)) - dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) + dataloader = DataLoader(dataset_module["eval_dataset"], batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) valid_tokens, total_tokens = 0, 0 for batch in tqdm(dataloader): valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item() diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py index 61f76922..3daa35ae 100644 --- a/scripts/cal_ppl.py +++ b/scripts/cal_ppl.py @@ -87,7 +87,7 @@ def cal_ppl( ) tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False) if stage == "pt": data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) @@ -100,7 +100,7 @@ def cal_ppl( else: raise NotImplementedError("Stage does not supported: {}.".format(stage)) - dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) + dataloader = DataLoader(dataset_module["eval_dataset"], batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) criterion = torch.nn.CrossEntropyLoss(reduction="none") total_ppl = 0 perplexities = [] diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py index 4cdf01e6..cef46416 100644 --- a/scripts/length_cdf.py +++ b/scripts/length_cdf.py @@ -47,10 +47,10 @@ def length_cdf( ) ) tokenizer_module = load_tokenizer(model_args) - trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) - total_num = len(trainset) + dataset_module = get_dataset(model_args, data_args, 
training_args, stage="sft", **tokenizer_module) + total_num = len(dataset_module["eval_dataset"]) length_dict = defaultdict(int) - for sample in tqdm(trainset["input_ids"]): + for sample in tqdm(dataset_module["eval_dataset"]["input_ids"]): length_dict[len(sample) // interval * interval] += 1 length_tuples = list(length_dict.items()) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 8e7062db..d527d7d2 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -15,7 +15,7 @@ import inspect import os import sys -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union, Dict import numpy as np from datasets import load_dataset, load_from_disk @@ -24,10 +24,10 @@ from ..extras.constants import FILEEXT2TYPE from ..extras.logging import get_logger from ..extras.misc import has_tokenized_data from .aligner import align_dataset -from .data_utils import merge_dataset +from .data_utils import merge_dataset, split_dataset from .parser import get_dataset_list from .preprocess import get_preprocess_and_print_func -from .template import get_template_and_fix_tokenizer +from .template import get_template_and_fix_tokenizer, Template if TYPE_CHECKING: @@ -91,7 +91,7 @@ def load_single_dataset( subset_name=data_name, data_dir=data_dir, data_files=data_files, - split=data_args.split, + split=dataset_attr.split, cache_dir=cache_dir, token=model_args.ms_hub_token, use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), @@ -111,7 +111,7 @@ def load_single_dataset( name=data_name, data_dir=data_dir, data_files=data_files, - split=data_args.split, + split=dataset_attr.split, cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, streaming=(data_args.streaming and (dataset_attr.load_from != "file")), @@ -140,20 +140,17 @@ def load_single_dataset( return align_dataset(dataset, dataset_attr, data_args, training_args) -def get_dataset( +def load_and_preprocess( model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], tokenizer: "PreTrainedTokenizer", + template: "Template", processor: Optional["ProcessorMixin"] = None, + is_eval: bool = False ) -> Union["Dataset", "IterableDataset"]: - template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) - if data_args.train_on_prompt and template.efficient_eos: - raise ValueError("Current template does not support `train_on_prompt`.") - - # Load tokenized dataset - if data_args.tokenized_path is not None: + if not is_eval and data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): logger.warning("Loading dataset from disk will ignore other data arguments.") dataset = load_from_disk(data_args.tokenized_path) @@ -165,9 +162,21 @@ def get_dataset( if data_args.streaming: raise ValueError("Turn off `streaming` when saving dataset to disk.") + if is_eval and data_args.eval_tokenized_path is not None: + if has_tokenized_data(data_args.eval_tokenized_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset = load_from_disk(data_args.eval_tokenized_path) + logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path)) + if data_args.streaming: + dataset = dataset.to_iterable_dataset() + return dataset + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + with 
training_args.main_process_first(desc="load dataset"): all_datasets = [] - for dataset_attr in get_dataset_list(data_args): + for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset): if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): raise ValueError("The dataset is not applicable in the current training stage.") @@ -190,13 +199,20 @@ def get_dataset( dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) - if data_args.tokenized_path is not None: + if not is_eval and data_args.tokenized_path is not None: if training_args.should_save: dataset.save_to_disk(data_args.tokenized_path) logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) sys.exit(0) + if is_eval and data_args.eval_tokenized_path is not None: + if training_args.should_save: + dataset.save_to_disk(data_args.eval_tokenized_path) + logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path)) + logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path)) + + sys.exit(0) if training_args.should_log: try: @@ -208,3 +224,24 @@ def get_dataset( raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") return dataset + + +def get_dataset( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"] = None +) -> Dict[str, "Dataset"]: + template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + + train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor) + + if data_args.eval_dataset or data_args.eval_tokenized_path: + eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True) + return {"train_dataset": train_dataset, "eval_dataset": eval_dataset} + else: + return split_dataset(train_dataset, data_args, training_args) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 5ae79774..c810ec8b 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -40,6 +40,7 @@ class DatasetAttr: subset: Optional[str] = None folder: Optional[str] = None num_samples: Optional[int] = None + split: Optional[str] = "train" # common columns system: Optional[str] = None tools: Optional[str] = None @@ -71,9 +72,9 @@ class DatasetAttr: setattr(self, key, obj.get(key, default)) -def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: - if data_args.dataset is not None: - dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] +def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]: + if dataset is not None: + dataset_names = [ds.strip() for ds in dataset.split(",")] else: dataset_names = [] @@ -122,6 +123,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("num_samples", 
dataset_info[name]) + if "split" in dataset_info[name]: + dataset_attr.set_attr("split", dataset_info[name]) if "columns" in dataset_info[name]: column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index a1025af7..7f7e62cd 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -33,6 +33,11 @@ class DataArguments: default=None, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, ) + eval_dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of provided dataset(s) to use for eval during training. " + "Use commas to separate multiple datasets."}, + ) dataset_dir: str = field( default="data", metadata={"help": "Path to the folder containing the datasets."}, @@ -105,6 +110,10 @@ class DataArguments: default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) + eval_tokenized_path: Optional[str] = field( + default=None, + metadata={"help": "Path to save or load the tokenized eval datasets."}, + ) def __post_init__(self): if self.streaming and self.val_size > 1e-6 and self.val_size < 1: diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index 431b5285..c004363a 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -41,7 +41,7 @@ def run_dpo( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = PairwiseDataCollatorWithPadding( @@ -71,7 +71,7 @@ def run_dpo( data_collator=data_collator, callbacks=callbacks, **tokenizer_module, - **split_dataset(dataset, data_args, training_args), + **dataset_module, ) # Training diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index 8182a184..b2d0c82e 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -41,7 +41,7 @@ def run_kto( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = KTODataCollatorWithPadding( @@ -68,7 +68,7 @@ def run_kto( data_collator=data_collator, callbacks=callbacks, **tokenizer_module, - **split_dataset(dataset, data_args, training_args), + **dataset_module, ) # Training diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index f52b80d6..53d9f18f 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -43,7 +43,7 @@ def run_ppo( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, 
training_args.do_train, add_valuehead=True) tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training @@ -63,7 +63,7 @@ def run_ppo( model=model, reward_model=reward_model, ref_model=ref_model, - dataset=dataset, + dataset=dataset_module["train_dataset"], data_collator=data_collator, **tokenizer_module, ) diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index b84a0e7d..2f27d6cd 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -42,7 +42,7 @@ def run_pt( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) @@ -54,7 +54,7 @@ def run_pt( data_collator=data_collator, callbacks=callbacks, **tokenizer_module, - **split_dataset(dataset, data_args, training_args), + **dataset_module, ) # Training diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index 384814cc..54fa7fd0 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -41,7 +41,7 @@ def run_rm( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) @@ -57,7 +57,7 @@ def run_rm( callbacks=callbacks, compute_metrics=compute_accuracy, **tokenizer_module, - **split_dataset(dataset, data_args, training_args), + **dataset_module, ) # Training @@ -81,7 +81,7 @@ def run_rm( # Predict if training_args.do_predict: - predict_results = trainer.predict(dataset, metric_key_prefix="predict") + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict") trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) trainer.save_predictions(predict_results) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index dea3c1a8..b0bacc33 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -43,7 +43,7 @@ def run_sft( ): tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) if training_args.predict_with_generate: @@ -76,7 +76,7 @@ def run_sft( compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy, preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor, **tokenizer_module, - **split_dataset(dataset, data_args, training_args), + **dataset_module, ) # Keyword arguments for 
`model.generate` @@ -105,12 +105,12 @@ def run_sft( # Predict if training_args.do_predict: - predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled predict_results.metrics.pop("predict_loss", None) trainer.log_metrics("predict", predict_results.metrics) trainer.save_metrics("predict", predict_results.metrics) - trainer.save_predictions(dataset, predict_results) + trainer.save_predictions(dataset_module["eval_dataset"], predict_results) # Create model card create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/tests/data/test_supervised.py b/tests/data/test_supervised.py index 9cb49615..7ad52ee8 100644 --- a/tests/data/test_supervised.py +++ b/tests/data/test_supervised.py @@ -47,7 +47,7 @@ def test_supervised(num_samples: int): model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS) tokenizer_module = load_tokenizer(model_args) tokenizer = tokenizer_module["tokenizer"] - tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) + dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module) ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) @@ -63,5 +63,5 @@ def test_supervised(num_samples: int): {"role": "assistant", "content": original_data[index]["output"]}, ] templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False) - decoded_result = tokenizer.decode(tokenized_data["input_ids"][index]) + decoded_result = tokenizer.decode(dataset_module["train_dataset"]["input_ids"][index]) assert templated_result == decoded_result From ddbd848e49ba30910e354c12c538132eb615a798 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 14 Jul 2024 21:27:04 +0800 Subject: [PATCH 2/8] Update README.md Former-commit-id: 9d64507bd5d47f096e81c90bfb347690afaaec2b --- data/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/README.md b/data/README.md index 0f14bef8..5a34bcbe 100644 --- a/data/README.md +++ b/data/README.md @@ -11,9 +11,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format. "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})", "ranking": "whether the dataset is a preference dataset or not. (default: False)", "subset": "the name of the subset. (optional, default: None)", + "split": "the name of dataset split to be used. (optional, default: train)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", - "num_samples": "the number of samples in the dataset used for training. (optional, default: None)", - "split": "which dataset split to use for training and evaluation (optional, default: train)", + "num_samples": "the number of samples in the dataset to be used. (optional, default: None)", "columns (optional)": { "prompt": "the column name in the dataset containing the prompts. (default: instruction)", "query": "the column name in the dataset containing the queries. 
(default: input)", From 140b512426a9a61bbb19a81df07850563460abb8 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 14 Jul 2024 23:04:34 +0800 Subject: [PATCH 3/8] Update parser.py Former-commit-id: 3d39d74003c4ca36f9c9b77f622d366383b0af7e --- src/llamafactory/data/parser.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index c810ec8b..c443b9d9 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -38,9 +38,9 @@ class DatasetAttr: ranking: bool = False # extra configs subset: Optional[str] = None + split: str = "train" folder: Optional[str] = None num_samples: Optional[int] = None - split: Optional[str] = "train" # common columns system: Optional[str] = None tools: Optional[str] = None @@ -72,7 +72,7 @@ class DatasetAttr: setattr(self, key, obj.get(key, default)) -def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]: +def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]: if dataset is not None: dataset_names = [ds.strip() for ds in dataset.split(",")] else: @@ -121,10 +121,9 @@ def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List[ dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("subset", dataset_info[name]) + dataset_attr.set_attr("split", dataset_info[name], default="train") dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("num_samples", dataset_info[name]) - if "split" in dataset_info[name]: - dataset_attr.set_attr("split", dataset_info[name]) if "columns" in dataset_info[name]: column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] From 2e9c9471daa67850f912c85e21441898837f52ce Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:50:06 +0800 Subject: [PATCH 4/8] Update loader.py Former-commit-id: a5b809516e7de1d6d5f4583089fee3028d0db01d --- src/llamafactory/data/loader.py | 249 ++++++++++++++++++-------------- 1 file changed, 139 insertions(+), 110 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index d527d7d2..069ea199 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect import os import sys -from typing import TYPE_CHECKING, Literal, Optional, Union, Dict +from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union import numpy as np -from datasets import load_dataset, load_from_disk +from datasets import DatasetDict, load_dataset, load_from_disk +from transformers.utils.versions import require_version from ..extras.constants import FILEEXT2TYPE from ..extras.logging import get_logger @@ -27,7 +27,7 @@ from .aligner import align_dataset from .data_utils import merge_dataset, split_dataset from .parser import get_dataset_list from .preprocess import get_preprocess_and_print_func -from .template import get_template_and_fix_tokenizer, Template +from .template import get_template_and_fix_tokenizer if TYPE_CHECKING: @@ -35,13 +35,15 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from ..hparams import DataArguments, ModelArguments + from .data_utils import DatasetModule from .parser import DatasetAttr + from .template import Template logger = get_logger(__name__) -def load_single_dataset( +def _load_single_dataset( dataset_attr: "DatasetAttr", model_args: "ModelArguments", data_args: "DataArguments", @@ -81,31 +83,24 @@ def load_single_dataset( raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from)) if dataset_attr.load_from == "ms_hub": - try: - from modelscope import MsDataset - from modelscope.utils.config_ds import MS_DATASETS_CACHE + require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + from modelscope import MsDataset + from modelscope.utils.config_ds import MS_DATASETS_CACHE - cache_dir = model_args.cache_dir or MS_DATASETS_CACHE - dataset = MsDataset.load( - dataset_name=data_path, - subset_name=data_name, - data_dir=data_dir, - data_files=data_files, - split=dataset_attr.split, - cache_dir=cache_dir, - token=model_args.ms_hub_token, - use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), - ) - if isinstance(dataset, MsDataset): - dataset = dataset.to_hf_dataset() - except ImportError: - raise ImportError("Please install modelscope via `pip install modelscope -U`") + cache_dir = model_args.cache_dir or MS_DATASETS_CACHE + dataset = MsDataset.load( + dataset_name=data_path, + subset_name=data_name, + data_dir=data_dir, + data_files=data_files, + split=dataset_attr.split, + cache_dir=cache_dir, + token=model_args.ms_hub_token, + use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")), + ) + if isinstance(dataset, MsDataset): + dataset = dataset.to_hf_dataset() else: - if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 - kwargs = {"trust_remote_code": True} - else: - kwargs = {} - dataset = load_dataset( path=data_path, name=data_name, @@ -115,7 +110,7 @@ def load_single_dataset( cache_dir=model_args.cache_dir, token=model_args.hf_hub_token, streaming=(data_args.streaming and (dataset_attr.load_from != "file")), - **kwargs, + trust_remote_code=True, ) if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True @@ -140,90 +135,64 @@ def load_single_dataset( return align_dataset(dataset, dataset_attr, data_args, training_args) -def load_and_preprocess( +def _get_merged_dataset( + dataset_names: Optional[Sequence[str]], model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], - tokenizer: 
"PreTrainedTokenizer", +) -> Optional[Union["Dataset", "IterableDataset"]]: + if dataset_names is None: + return None + + datasets = [] + for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir): + if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): + raise ValueError("The dataset is not applicable in the current training stage.") + + datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args)) + + return merge_dataset(datasets, data_args, seed=training_args.seed) + + +def _get_preprocessed_dataset( + dataset: Optional[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", + tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, - is_eval: bool = False -) -> Union["Dataset", "IterableDataset"]: - if not is_eval and data_args.tokenized_path is not None: - if has_tokenized_data(data_args.tokenized_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") - dataset = load_from_disk(data_args.tokenized_path) - logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) - if data_args.streaming: - dataset = dataset.to_iterable_dataset() - return dataset + is_eval: bool = False, +) -> Optional[Union["Dataset", "IterableDataset"]]: + if dataset is None: + return None - if data_args.streaming: - raise ValueError("Turn off `streaming` when saving dataset to disk.") - - if is_eval and data_args.eval_tokenized_path is not None: - if has_tokenized_data(data_args.eval_tokenized_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") - dataset = load_from_disk(data_args.eval_tokenized_path) - logger.info("Loaded tokenized dataset from {}.".format(data_args.eval_tokenized_path)) - if data_args.streaming: - dataset = dataset.to_iterable_dataset() - return dataset - - if data_args.streaming: - raise ValueError("Turn off `streaming` when saving dataset to disk.") - - with training_args.main_process_first(desc="load dataset"): - all_datasets = [] - for dataset_attr in get_dataset_list(data_args, data_args.eval_dataset if is_eval else data_args.dataset): - if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): - raise ValueError("The dataset is not applicable in the current training stage.") - - all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args)) - - dataset = merge_dataset(all_datasets, data_args, training_args) - - with training_args.main_process_first(desc="pre-process dataset"): - preprocess_func, print_function = get_preprocess_and_print_func( - data_args, training_args, stage, template, tokenizer, processor + preprocess_func, print_function = get_preprocess_and_print_func( + data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) + ) + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + if not data_args.streaming: + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Running tokenizer on dataset", ) - column_names = list(next(iter(dataset)).keys()) - kwargs = {} - if not data_args.streaming: - kwargs = dict( - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=(not 
data_args.overwrite_cache) or (training_args.local_process_index != 0), - desc="Running tokenizer on dataset", - ) - dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) + dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) - if not is_eval and data_args.tokenized_path is not None: - if training_args.should_save: - dataset.save_to_disk(data_args.tokenized_path) - logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) - logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) + if training_args.should_log: + try: + print("eval example:" if is_eval else "training example:") + print_function(next(iter(dataset))) + except StopIteration: + if stage == "pt": + raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") + else: + raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") - sys.exit(0) - if is_eval and data_args.eval_tokenized_path is not None: - if training_args.should_save: - dataset.save_to_disk(data_args.eval_tokenized_path) - logger.info("Tokenized dataset saved at {}.".format(data_args.eval_tokenized_path)) - logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.eval_tokenized_path)) - - sys.exit(0) - - if training_args.should_log: - try: - print_function(next(iter(dataset))) - except StopIteration: - if stage == "pt": - raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.") - else: - raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") - - return dataset + return dataset def get_dataset( @@ -232,16 +201,76 @@ def get_dataset( training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"] = None -) -> Dict[str, "Dataset"]: + processor: Optional["ProcessorMixin"] = None, +) -> "DatasetModule": template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) if data_args.train_on_prompt and template.efficient_eos: raise ValueError("Current template does not support `train_on_prompt`.") - train_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor) + # Load tokenized dataset + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) + logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) - if data_args.eval_dataset or data_args.eval_tokenized_path: - eval_dataset = load_and_preprocess(model_args, data_args, training_args, stage, tokenizer, template, processor, True) - return {"train_dataset": train_dataset, "eval_dataset": eval_dataset} - else: - return split_dataset(train_dataset, data_args, training_args) + dataset_module: Dict[str, "Dataset"] = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + if data_args.streaming: + dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + + return dataset_module + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + + # 
Load and preprocess dataset + with training_args.main_process_first(desc="load dataset"): + dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage) + + with training_args.main_process_first(desc="pre-process dataset"): + dataset = _get_preprocessed_dataset( + dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False + ) + eval_dataset = _get_preprocessed_dataset( + eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True + ) + + if data_args.val_size > 1e-6: + dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed) + else: + dataset_dict = {} + if dataset is not None: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["train"] = dataset + + if eval_dataset is not None: + if data_args.streaming: + eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + + dataset_dict["validation"] = eval_dataset + + dataset_dict = DatasetDict(dataset_dict) + + if data_args.tokenized_path is not None: + if training_args.should_save: + dataset_dict.save_to_disk(data_args.tokenized_path) + logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) + logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) + + sys.exit(0) + + dataset_module = {} + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] + + return dataset_module From 5633c0ab1e0c2ec818940df2accd1f8bbd800958 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:54:34 +0800 Subject: [PATCH 5/8] Update data_utils.py Former-commit-id: 97a0e291c79f145950b54a11d03d81ada4784d22 --- src/llamafactory/data/data_utils.py | 51 ++++++++++++++--------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index 76ded47e..4666aabc 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -13,16 +13,15 @@ # limitations under the License. 
from enum import Enum, unique -from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union -from datasets import concatenate_datasets, interleave_datasets +from datasets import DatasetDict, concatenate_datasets, interleave_datasets from ..extras.logging import get_logger if TYPE_CHECKING: from datasets import Dataset, IterableDataset - from transformers import Seq2SeqTrainingArguments from ..hparams import DataArguments @@ -42,24 +41,29 @@ class Role(str, Enum): OBSERVATION = "observation" +class DatasetModule(TypedDict): + train_dataset: Optional[Union["Dataset", "IterableDataset"]] + eval_dataset: Optional[Union["Dataset", "IterableDataset"]] + + def merge_dataset( - all_datasets: List[Union["Dataset", "IterableDataset"]], - data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", + all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int ) -> Union["Dataset", "IterableDataset"]: if len(all_datasets) == 1: return all_datasets[0] elif data_args.mix_strategy == "concat": if data_args.streaming: logger.warning("The samples between different datasets will not be mixed in streaming mode.") + return concatenate_datasets(all_datasets) elif data_args.mix_strategy.startswith("interleave"): if not data_args.streaming: logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + return interleave_datasets( datasets=all_datasets, probabilities=data_args.interleave_probs, - seed=training_args.seed, + seed=seed, stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", ) else: @@ -67,22 +71,17 @@ def merge_dataset( def split_dataset( - dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments" -) -> Dict[str, "Dataset"]: - if training_args.do_train: - if data_args.val_size > 1e-6: # Split the dataset - if data_args.streaming: - dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) - val_set = dataset.take(int(data_args.val_size)) - train_set = dataset.skip(int(data_args.val_size)) - return {"train_dataset": train_set, "eval_dataset": val_set} - else: - val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size - dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) - return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} - else: - if data_args.streaming: - dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) - return {"train_dataset": dataset} - else: # do_eval or do_predict - return {"eval_dataset": dataset} + dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int +) -> "DatasetDict": + r""" + Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional). 
+ """ + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + return DatasetDict({"train": train_set, "validation": val_set}) + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=seed) + return DatasetDict({"train": dataset["train"], "validation": dataset["test"]}) From eed7cbb453ce9a53a8811504d03903b17cc329a0 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:55:21 +0800 Subject: [PATCH 6/8] Update parser.py Former-commit-id: 84e4047f8a1f78256be65f3f7bddce358ed9e882 --- src/llamafactory/data/parser.py | 38 ++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index c443b9d9..2dccfc5d 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -15,16 +15,14 @@ import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Sequence + +from transformers.utils import cached_file from ..extras.constants import DATA_CONFIG from ..extras.misc import use_modelscope -if TYPE_CHECKING: - from ..hparams import DataArguments - - @dataclass class DatasetAttr: r""" @@ -72,31 +70,33 @@ class DatasetAttr: setattr(self, key, obj.get(key, default)) -def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]: - if dataset is not None: - dataset_names = [ds.strip() for ds in dataset.split(",")] - else: +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: + r""" + Gets the attributes of the datasets. 
+ """ + if dataset_names is None: dataset_names = [] - if data_args.dataset_dir == "ONLINE": + if dataset_dir == "ONLINE": dataset_info = None else: + if dataset_dir.startswith("REMOTE:"): + config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + else: + config_path = os.path.join(dataset_dir, DATA_CONFIG) + try: - with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: + with open(config_path, "r") as f: dataset_info = json.load(f) except Exception as err: if len(dataset_names) != 0: - raise ValueError( - "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) - ) + raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) + dataset_info = None - if data_args.interleave_probs is not None: - data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] - - dataset_list: List[DatasetAttr] = [] + dataset_list: List["DatasetAttr"] = [] for name in dataset_names: - if dataset_info is None: + if dataset_info is None: # dataset_dir is ONLINE load_from = "ms_hub" if use_modelscope() else "hf_hub" dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_list.append(dataset_attr) From 30a3c6e886e07591ae83b07c568d4ae17a701a6e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:55:36 +0800 Subject: [PATCH 7/8] Update preprocess.py Former-commit-id: df52fb05b1b08887288bbaab7c612b7ac27c2290 --- src/llamafactory/data/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 2ea2fa1d..caf4a9b8 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments + from transformers import PreTrainedTokenizer, ProcessorMixin from ..hparams import DataArguments from .template import Template @@ -35,11 +35,11 @@ if TYPE_CHECKING: def get_preprocess_and_print_func( data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], + do_generate: bool = False, ) -> Tuple[Callable, Callable]: if stage == "pt": preprocess_func = partial( @@ -48,7 +48,7 @@ def get_preprocess_and_print_func( data_args=data_args, ) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) - elif stage == "sft" and not training_args.predict_with_generate: + elif stage == "sft" and not do_generate: if data_args.packing: if data_args.neat_packing: from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence From 788dc1c67919a401cce0c96304830be995a7b0b9 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 15 Jul 2024 00:56:03 +0800 Subject: [PATCH 8/8] Update data_args.py Former-commit-id: cba673f491c5d97aba62aea03f310bd54fb3fe28 --- src/llamafactory/hparams/data_args.py | 34 +++++++++++++++++++++------ 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 7f7e62cd..f483099d 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -31,12 +31,11 @@ class DataArguments: ) dataset: Optional[str] = field( 
default=None, - metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, + metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, ) eval_dataset: Optional[str] = field( default=None, - metadata={"help": "The name of provided dataset(s) to use for eval during training. " - "Use commas to separate multiple datasets."}, + metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, ) dataset_dir: str = field( default="data", @@ -110,12 +109,33 @@ class DataArguments: default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) - eval_tokenized_path: Optional[str] = field( - default=None, - metadata={"help": "Path to save or load the tokenized eval datasets."}, - ) def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.dataset = split_arg(self.dataset) + self.eval_dataset = split_arg(self.eval_dataset) + + if self.dataset is None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `dataset` is None.") + + if self.eval_dataset is not None and self.val_size > 1e-6: + raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.") + + if self.interleave_probs is not None: + if self.mix_strategy == "concat": + raise ValueError("`interleave_probs` is only valid for interleaved mixing.") + + self.interleave_probs = list(map(float, split_arg(self.interleave_probs))) + if self.dataset is not None and len(self.dataset) != len(self.interleave_probs): + raise ValueError("The length of dataset and interleave probs should be identical.") + + if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs): + raise ValueError("The length of eval dataset and interleave probs should be identical.") + if self.streaming and self.val_size > 1e-6 and self.val_size < 1: raise ValueError("Streaming mode should have an integer val size.")
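
Editor's note — taken together, these patches change `get_dataset` from returning a single dataset into returning a dict-like `DatasetModule` with optional `train_dataset` / `eval_dataset` entries, populated from the new `eval_dataset` argument, the per-dataset `split` field in `dataset_info.json`, or the existing `val_size` split. The sketch below shows one way a standalone script might consume the new return value. It is illustrative only: the argument values (model path, the `adgen_train`/`adgen_val` dataset names, template, cutoff length) are assumptions modeled on the entries added in PATCH 1/8 and on `tests/data/test_supervised.py`, not part of the series itself.

# Minimal usage sketch (not part of the patches above). Assumptions: the package is
# installed as `llamafactory`, and the `adgen_train` / `adgen_val` entries added to
# data/dataset_info.json in PATCH 1/8 are available; adjust names and paths as needed.
from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer

# Argument values are illustrative, loosely mirroring tests/data/test_supervised.py.
model_args, data_args, training_args, _, _ = get_train_args(dict(
    stage="sft",
    do_train=True,
    finetuning_type="full",
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # assumption: any tokenizer-bearing model works here
    dataset="adgen_train",       # uses split: "train" from dataset_info.json
    eval_dataset="adgen_val",    # uses split: "validation"; cannot be combined with val_size
    template="default",
    cutoff_len=1024,
    output_dir="dummy_dir",
))

tokenizer_module = load_tokenizer(model_args)

# get_dataset now returns a DatasetModule-style dict instead of a single Dataset.
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
train_dataset = dataset_module.get("train_dataset")  # present when `dataset` is set
eval_dataset = dataset_module.get("eval_dataset")    # present when `eval_dataset` or `val_size` is set
print(len(train_dataset), len(eval_dataset) if eval_dataset is not None else 0)

The training workflows in this series rely on the same shape: they forward the dict unchanged to the Trainer via `**dataset_module`, so a run with `eval_dataset` set gets a held-out split without touching `val_size`.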