Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 11:20:35 +08:00)
merge data part to the text stream
@@ -1,6 +1,6 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Literal, Union, Optional
 
 from datasets import load_dataset, load_from_disk
 
@@ -25,9 +25,9 @@ logger = get_logger(__name__)
 
 
 def load_single_dataset(
-    dataset_attr: "DatasetAttr",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
+    dataset_attr: "DatasetAttr",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -78,14 +78,20 @@ def load_single_dataset(
                 split=data_args.split,
                 cache_dir=cache_dir,
                 token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+                use_streaming=(
+                    data_args.streaming and (dataset_attr.load_from != "file")
+                ),
             )
             if isinstance(dataset, MsDataset):
                 dataset = dataset.to_hf_dataset()
         except ImportError:
-            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+            raise ImportError(
+                "Please install modelscope via `pip install modelscope -U`"
+            )
     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+        if (
+            "trust_remote_code" in inspect.signature(load_dataset).parameters
+        ):  # for datasets==2.16.0
             kwargs = {"trust_remote_code": True}
         else:
             kwargs = {}
@@ -102,7 +108,9 @@ def load_single_dataset(
             **kwargs,
         )
 
-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
+    if data_args.streaming and (
+        dataset_attr.load_from == "file"
+    ):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
     if data_args.max_samples is not None:  # truncate dataset
@@ -113,11 +121,12 @@
 
 
 def get_dataset(
-    tokenizer: "PreTrainedTokenizer",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"],
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+    processor: Optional["AutoProcessor"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:
@@ -126,9 +135,13 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning(
+                "Loading dataset from disk will ignore other data arguments."
+            )
             dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+            logger.info(
+                "Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
+            )
             if data_args.streaming:
                 dataset = dataset.to_iterable_dataset()
             return dataset
@@ -139,15 +152,21 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
         for dataset_attr in get_dataset_list(data_args):
-            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
-                raise ValueError("The dataset is not applicable in the current training stage.")
+            if (stage == "rm" and dataset_attr.ranking is False) or (
+                stage != "rm" and dataset_attr.ranking is True
+            ):
+                raise ValueError(
+                    "The dataset is not applicable in the current training stage."
+                )
 
-            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
+            all_datasets.append(
+                load_single_dataset(dataset_attr, model_args, data_args)
+            )
         dataset = merge_dataset(all_datasets, data_args, training_args)
 
     with training_args.main_process_first(desc="pre-process dataset"):
         preprocess_func, print_function = get_preprocess_and_print_func(
-            tokenizer, template, data_args, training_args, stage
+            tokenizer, template, data_args, training_args, stage, processor
         )
         column_names = list(next(iter(dataset)).keys())
         kwargs = {}
@@ -158,13 +177,21 @@ def get_dataset(
                 desc="Running tokenizer on dataset",
             )
 
-        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+        dataset = dataset.map(
+            preprocess_func, batched=True, remove_columns=column_names, **kwargs
+        )
 
         if data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset.save_to_disk(data_args.tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-            logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
+            logger.info(
+                "Tokenized dataset saved at {}.".format(data_args.tokenized_path)
+            )
+            logger.info(
+                "Please restart the training with `--tokenized_path {}`.".format(
+                    data_args.tokenized_path
+                )
+            )
 
             exit(0)
 
@@ -172,34 +199,8 @@ def get_dataset(
         try:
             print_function(next(iter(dataset)))
         except StopIteration:
-            raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+            raise RuntimeError(
+                "Cannot find valid samples, check `data/README.md` for the data format."
+            )
 
     return dataset
-
-
-def get_mm_dataset(
-    processor: "AutoProcessor",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"],
-) -> Union["Dataset", "IterableDataset"]:
-    if data_args.tokenized_path is not None:
-        if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-
-        if data_args.streaming:
-            raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    with training_args.main_process_first(desc="load dataset"):
-        all_datasets = []
-        for dataset_attr in get_dataset_list(data_args):
-            all_datasets.append(load_dataset(dataset_attr.dataset_name)['train'])
-        dataset = merge_dataset(all_datasets, data_args, training_args)
-
-    return dataset
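Taken together, the hunks above drop the separate get_mm_dataset helper and fold multimodal data handling into get_dataset: the new optional processor argument is accepted alongside the tokenizer and forwarded to get_preprocess_and_print_func, so the image part of each sample is merged into the same text preprocessing stream. Below is a minimal sketch (not part of the commit) of the resulting call pattern; the checkpoint name and the pre-parsed argument objects are placeholder assumptions.

# Sketch only: how a caller would use the updated get_dataset() now that
# get_mm_dataset() is removed. Assumes the llmtuner package from this repo
# is importable and that the argument dataclasses were parsed elsewhere
# (e.g. by the project's CLI entry point).
from transformers import AutoProcessor, Seq2SeqTrainingArguments

from llmtuner.data import get_dataset
from llmtuner.hparams import DataArguments, ModelArguments

model_args: ModelArguments = ...               # assumed to be parsed elsewhere
data_args: DataArguments = ...                 # assumed to be parsed elsewhere
training_args: Seq2SeqTrainingArguments = ...  # assumed to be parsed elsewhere

# Hypothetical multimodal checkpoint; any processor that bundles a tokenizer works.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

dataset = get_dataset(
    tokenizer=processor.tokenizer,  # text still goes through the tokenizer
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    stage="sft",
    processor=processor,            # omit or pass None for the text-only path
)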