Former-commit-id: 0c22da4f1cc710b471f6d511d50ce878521173ca
This commit is contained in:
hiyouga 2024-10-30 08:56:29 +00:00
parent aba4268607
commit 1b02915d19
3 changed files with 16 additions and 23 deletions

View File

@ -69,25 +69,24 @@ def _load_single_dataset(
if os.path.isdir(local_path): # is directory
for file_name in os.listdir(local_path):
data_files.append(os.path.join(local_path, file_name))
if data_path is None:
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
raise ValueError("File types should be identical.")
elif os.path.isfile(local_path): # is file
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError(f"File {local_path} not found.")
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
if data_path is None:
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
raise ValueError("File types should be identical.")
else:
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
from modelscope import MsDataset
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
dataset = MsDataset.load(
@ -98,15 +97,15 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.ms_hub_token,
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
use_streaming=data_args.streaming,
)
if isinstance(dataset, MsDataset):
dataset = dataset.to_hf_dataset()
elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
from openmind import OmDataset
from openmind.utils.hub import OM_DATASETS_CACHE
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
dataset = OmDataset.load_dataset(
@ -117,7 +116,7 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=cache_dir,
token=model_args.om_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
streaming=data_args.streaming,
)
else:
dataset = load_dataset(
@ -128,13 +127,10 @@ def _load_single_dataset(
split=dataset_attr.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
streaming=data_args.streaming,
trust_remote_code=True,
)
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included

View File

@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError(
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
)
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
return messages

View File

@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
if data_args.template is None:
template = TEMPLATES["empty"] # placeholder
else:
@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if template is None:
raise ValueError(f"Template {data_args.template} does not exist.")
if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")