Former-commit-id: 0c22da4f1cc710b471f6d511d50ce878521173ca
hiyouga 2024-10-30 08:56:29 +00:00
parent aba4268607
commit 1b02915d19
3 changed files with 16 additions and 23 deletions

View File

@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
             raise ValueError(f"File {local_path} not found.")

+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,15 +97,15 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()

     elif dataset_attr.load_from == "om_hub":
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind import OmDataset
-        from openmind.utils.hub import OM_DATASETS_CACHE
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
         dataset = OmDataset.load_dataset(
@@ -117,7 +116,7 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.om_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
         )
     else:
         dataset = load_dataset(
@@ -128,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )

-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
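The loader refactor above consolidates duplicated type inference: instead of guessing the dataset type per file while scanning a directory, the new code collects data_files first, derives the type once from the first file's extension, and then checks that every file maps to the same type. Below is a minimal standalone sketch of that consolidated check; the FILEEXT2TYPE mapping here is illustrative, not the repo's exact table.

import os

# Illustrative extension-to-loader-type mapping (an assumption for this
# sketch; the real FILEEXT2TYPE is defined in the repo's constants).
FILEEXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}

def detect_data_path(data_files):
    # Derive the loader type from the first file's extension.
    data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    # Every remaining file must map to the same loader type. The comparison is
    # on mapped types, not raw extensions, so "a.json" and "b.jsonl" coexist.
    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path

detect_data_path(["a.jsonl", "b.json"])  # -> "json"
# detect_data_path(["a.json", "b.csv"])  # raises ValueError

The streaming simplification is behavior-preserving in the hub branches, where dataset_attr.load_from is never "file" and the old guard was therefore always true; for local files it replaces the removed to_iterable_dataset() post-conversion by passing streaming directly to load_dataset.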

View File

@@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(
-                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
-                    )
+                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

         return messages
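Both PixtralPlugin edits are formatting only, collapsing a multi-line .format() call into an f-string; the validation they report on is unchanged: every IMAGE_PLACEHOLDER in the message contents must be backed by exactly one input image. A minimal sketch of that invariant, with an assumed placeholder literal and message shape:

IMAGE_PLACEHOLDER = "<image>"  # assumption for this sketch; the repo defines its own constant

def check_image_count(messages, images):
    # Count placeholder tokens across all message contents.
    num_image_tokens = sum(m["content"].count(IMAGE_PLACEHOLDER) for m in messages)
    if len(images) != num_image_tokens:
        raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

check_image_count([{"content": f"{IMAGE_PLACEHOLDER} describe this"}], images=["img0"])  # passes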

View File

@@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
@@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     if template is None:
         raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")