merge data part to the text stream

BUAADreamer
2024-04-25 19:19:59 +08:00
parent 838eb87a96
commit c6dd89918f
15 changed files with 828 additions and 293 deletions


@@ -1,6 +1,6 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Literal, Union, Optional
 
 from datasets import load_dataset, load_from_disk
@@ -25,9 +25,9 @@ logger = get_logger(__name__)
 def load_single_dataset(
-    dataset_attr: "DatasetAttr",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
+    dataset_attr: "DatasetAttr",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
@@ -78,14 +78,20 @@ def load_single_dataset(
                 split=data_args.split,
                 cache_dir=cache_dir,
                 token=model_args.ms_hub_token,
-                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+                use_streaming=(
+                    data_args.streaming and (dataset_attr.load_from != "file")
+                ),
             )
             if isinstance(dataset, MsDataset):
                 dataset = dataset.to_hf_dataset()
         except ImportError:
-            raise ImportError("Please install modelscope via `pip install modelscope -U`")
+            raise ImportError(
+                "Please install modelscope via `pip install modelscope -U`"
+            )
     else:
-        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+        if (
+            "trust_remote_code" in inspect.signature(load_dataset).parameters
+        ):  # for datasets==2.16.0
             kwargs = {"trust_remote_code": True}
         else:
             kwargs = {}
@@ -102,7 +108,9 @@ def load_single_dataset(
             **kwargs,
         )
 
-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
+    if data_args.streaming and (
+        dataset_attr.load_from == "file"
+    ):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
 
     if data_args.max_samples is not None:  # truncate dataset
@@ -113,11 +121,12 @@ def load_single_dataset(
 def get_dataset(
-    tokenizer: "PreTrainedTokenizer",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"],
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    stage: Literal["pt", "sft", "rm", "ppo"],
+    processor: Optional["AutoProcessor"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
     template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
     if data_args.train_on_prompt and template.efficient_eos:
@@ -126,9 +135,13 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning(
+                "Loading dataset from disk will ignore other data arguments."
+            )
             dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+            logger.info(
+                "Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
+            )
             if data_args.streaming:
                 dataset = dataset.to_iterable_dataset()
             return dataset
@@ -139,15 +152,21 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args):
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")
if (stage == "rm" and dataset_attr.ranking is False) or (
stage != "rm" and dataset_attr.ranking is True
):
raise ValueError(
"The dataset is not applicable in the current training stage."
)
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
all_datasets.append(
load_single_dataset(dataset_attr, model_args, data_args)
)
dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
tokenizer, template, data_args, training_args, stage, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
@@ -158,13 +177,21 @@ def get_dataset(
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
dataset = dataset.map(
preprocess_func, batched=True, remove_columns=column_names, **kwargs
)
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
logger.info(
"Tokenized dataset saved at {}.".format(data_args.tokenized_path)
)
logger.info(
"Please restart the training with `--tokenized_path {}`.".format(
data_args.tokenized_path
)
)
exit(0)
@@ -172,34 +199,8 @@ def get_dataset(
         try:
             print_function(next(iter(dataset)))
         except StopIteration:
-            raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
+            raise RuntimeError(
+                "Cannot find valid samples, check `data/README.md` for the data format."
+            )
 
     return dataset
-
-
-def get_mm_dataset(
-    processor: "AutoProcessor",
-    model_args: "ModelArguments",
-    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "ppo"],
-) -> Union["Dataset", "IterableDataset"]:
-    if data_args.tokenized_path is not None:
-        if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
-            dataset = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
-            if data_args.streaming:
-                dataset = dataset.to_iterable_dataset()
-            return dataset
-
-    if data_args.streaming:
-        raise ValueError("Turn off `streaming` when saving dataset to disk.")
-
-    with training_args.main_process_first(desc="load dataset"):
-        all_datasets = []
-        for dataset_attr in get_dataset_list(data_args):
-            all_datasets.append(load_dataset(dataset_attr.dataset_name)['train'])
-        dataset = merge_dataset(all_datasets, data_args, training_args)
-
-    return dataset
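
Below is a minimal usage sketch (not part of this commit) of the entry point after the change: get_dataset() now accepts an optional processor and forwards it to get_preprocess_and_print_func(), so multimodal datasets flow through the same loading and preprocessing path as plain text and the separate get_mm_dataset() helper is removed. The import paths, model id, dataset name, and argument fields below are assumptions for illustration only.

from transformers import AutoProcessor, AutoTokenizer, Seq2SeqTrainingArguments

from llmtuner.data import get_dataset  # assumed export of the function changed above
from llmtuner.hparams import DataArguments, ModelArguments  # assumed hparam dataclasses

# Illustrative arguments; the model id, dataset name, and field values are placeholders.
model_args = ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf")
data_args = DataArguments(template="vicuna", dataset="mllm_demo")
training_args = Seq2SeqTrainingArguments(output_dir="saves/demo")

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)

# Text-only runs keep the old call shape, since `processor` defaults to None.
dataset = get_dataset(
    tokenizer, model_args, data_args, training_args, stage="sft", processor=processor
)

Passing the processor here is what lets the multimodal preprocessing happen inside the shared preprocess_func rather than in a dedicated multimodal loader.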