modify style

BUAADreamer
2024-04-25 21:15:16 +08:00
parent 43d7ad5ecc
commit 1dcabafe72
16 changed files with 374 additions and 502 deletions


@@ -1,6 +1,6 @@
 import inspect
 import os
-from typing import TYPE_CHECKING, Literal, Union, Optional
+from typing import TYPE_CHECKING, Literal, Optional, Union

 from datasets import load_dataset, load_from_disk
@@ -13,9 +13,10 @@ from .preprocess import get_preprocess_and_print_func
 from .template import get_template_and_fix_tokenizer
 from .utils import checksum, merge_dataset

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import Seq2SeqTrainingArguments, AutoProcessor
+    from transformers import AutoProcessor, Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
+
     from ..hparams import DataArguments, ModelArguments
@@ -78,20 +79,14 @@ def load_single_dataset(
                 split=data_args.split,
                 cache_dir=cache_dir,
                 token=model_args.ms_hub_token,
-                use_streaming=(
-                    data_args.streaming and (dataset_attr.load_from != "file")
-                ),
+                use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
             )
             if isinstance(dataset, MsDataset):
                 dataset = dataset.to_hf_dataset()
         except ImportError:
-            raise ImportError(
-                "Please install modelscope via `pip install modelscope -U`"
-            )
+            raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
-        if (
-            "trust_remote_code" in inspect.signature(load_dataset).parameters
-        ):  # for datasets==2.16.0
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
             kwargs = {"trust_remote_code": True}
         else:
             kwargs = {}
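Aside: the `trust_remote_code` guard above is a feature-detection pattern, passing the keyword only when the installed `datasets` accepts it. A minimal standalone sketch of the same check (the JSON file name is hypothetical, not from this commit):

import inspect

from datasets import load_dataset

# Only pass trust_remote_code when the installed `datasets` supports it;
# per the comment in the diff, the keyword appeared around datasets 2.16.0.
kwargs = {}
if "trust_remote_code" in inspect.signature(load_dataset).parameters:
    kwargs["trust_remote_code"] = True

# "data.json" is a hypothetical local file used for illustration.
dataset = load_dataset("json", data_files="data.json", split="train", **kwargs)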
@@ -108,9 +103,7 @@ def load_single_dataset(
             **kwargs,
         )

-    if data_args.streaming and (
-        dataset_attr.load_from == "file"
-    ):  # faster than specifying streaming=True
+    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
         dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

     if data_args.max_samples is not None:  # truncate dataset
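Aside: per the comment above, for local files the loader parses the data eagerly and converts it afterwards instead of passing `streaming=True` to `load_dataset`. A rough sketch of that conversion, assuming a hypothetical local file:

from datasets import load_dataset

# Eager load: the file is parsed once into an Arrow-backed Dataset.
# "data.json" is a hypothetical local file.
dataset = load_dataset("json", data_files="data.json", split="train")

# Lazy iteration over the already-parsed data; the diff's comment says this
# is faster than load_dataset(..., streaming=True) for file-backed datasets.
stream = dataset.to_iterable_dataset()  # the TODO above refers to a num-shards setting
for sample in stream:
    print(sample)
    break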
@@ -135,13 +128,9 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning(
-                "Loading dataset from disk will ignore other data arguments."
-            )
+            logger.warning("Loading dataset from disk will ignore other data arguments.")
             dataset = load_from_disk(data_args.tokenized_path)
-            logger.info(
-                "Loaded tokenized dataset from {}.".format(data_args.tokenized_path)
-            )
+            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
             if data_args.streaming:
                 dataset = dataset.to_iterable_dataset()
             return dataset
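Aside: this branch reloads a previously tokenized dataset via the `datasets` on-disk format. A minimal sketch of the save/load round trip, with a made-up path standing in for `data_args.tokenized_path`:

from datasets import Dataset, load_from_disk

# Toy tokenized dataset; "/tmp/tokenized" stands in for data_args.tokenized_path.
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})
dataset.save_to_disk("/tmp/tokenized")

reloaded = load_from_disk("/tmp/tokenized")  # other data arguments are ignored, as the warning says
iterable = reloaded.to_iterable_dataset()    # mirrors the streaming branch above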
@@ -152,16 +141,10 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
         for dataset_attr in get_dataset_list(data_args):
-            if (stage == "rm" and dataset_attr.ranking is False) or (
-                stage != "rm" and dataset_attr.ranking is True
-            ):
-                raise ValueError(
-                    "The dataset is not applicable in the current training stage."
-                )
+            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+                raise ValueError("The dataset is not applicable in the current training stage.")

-            all_datasets.append(
-                load_single_dataset(dataset_attr, model_args, data_args)
-            )
+            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
         dataset = merge_dataset(all_datasets, data_args, training_args)

     with training_args.main_process_first(desc="pre-process dataset"):
@@ -177,21 +160,13 @@ def get_dataset(
desc="Running tokenizer on dataset",
)
dataset = dataset.map(
preprocess_func, batched=True, remove_columns=column_names, **kwargs
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path)
logger.info(
"Tokenized dataset saved at {}.".format(data_args.tokenized_path)
)
logger.info(
"Please restart the training with `--tokenized_path {}`.".format(
data_args.tokenized_path
)
)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
exit(0)
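Aside: the `dataset.map` call above runs the tokenizer in batches and drops the raw columns. A self-contained sketch of the same pattern, with toy column names and a stand-in preprocessing function:

from datasets import Dataset

dataset = Dataset.from_dict({"prompt": ["hello", "world"]})

def preprocess_func(examples):
    # Stand-in for the real tokenizer: batched=True hands over a dict of lists.
    return {"input_ids": [[ord(ch) for ch in text] for text in examples["prompt"]]}

# remove_columns drops the raw "prompt" column so only tokenized fields remain.
dataset = dataset.map(preprocess_func, batched=True, remove_columns=["prompt"])
print(dataset[0])  # {'input_ids': [104, 101, 108, 108, 111]}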
@@ -199,8 +174,6 @@ def get_dataset(
         try:
             print_function(next(iter(dataset)))
         except StopIteration:
-            raise RuntimeError(
-                "Cannot find valid samples, check `data/README.md` for the data format."
-            )
+            raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

     return dataset
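Aside: the final check surfaces an empty dataset early, since `next(iter(dataset))` raises `StopIteration` when preprocessing left no samples. A minimal reproduction sketch:

from datasets import Dataset

empty = Dataset.from_dict({"input_ids": []})  # zero rows, as if every sample was filtered out
try:
    print(next(iter(empty)))
except StopIteration:
    raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")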