From 1b02915d192973fc9515aaa9e113464c559854af Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 30 Oct 2024 08:56:29 +0000
Subject: [PATCH] tiny fix

Former-commit-id: 0c22da4f1cc710b471f6d511d50ce878521173ca
---
 src/llamafactory/data/loader.py    | 26 +++++++++++---------------
 src/llamafactory/data/mm_plugin.py |  6 ++----
 src/llamafactory/data/template.py  |  7 +++----
 3 files changed, 16 insertions(+), 23 deletions(-)

diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 2ac13c94..1cb9c686 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
             raise ValueError(f"File {local_path} not found.")
 
+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 
     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
 
         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,15 +97,15 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
 
     elif dataset_attr.load_from == "om_hub":
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind import OmDataset
-        from openmind.utils.hub import OM_DATASETS_CACHE
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
 
         cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
         dataset = OmDataset.load_dataset(
@@ -117,7 +116,7 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.om_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
         )
     else:
         dataset = load_dataset(
@@ -128,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )
 
-        if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-            dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 52c65cb7..4e096c83 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(
-                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
-                    )
+                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 
                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 
         return messages
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index d0da3b30..bf07ec95 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
    """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
@@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         if template is None:
             raise ValueError(f"Template {data_args.template} does not exist.")
 
+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
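
For reference, a minimal standalone sketch of the consolidated file-type check that this patch introduces in loader.py. The FILEEXT2TYPE dict below is a trimmed stand-in for the project's real extension-to-format table, and detect_data_path is a hypothetical helper name used only for illustration; the patched code inlines the same logic directly in _load_single_dataset.

import os
from typing import Dict, List, Optional

# Trimmed stand-in for the project's FILEEXT2TYPE mapping (file extension -> datasets loader type).
FILEEXT2TYPE: Dict[str, str] = {
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}


def detect_data_path(data_files: List[str]) -> str:
    """Infer the loader type from the first file, then require every file to map to the same type."""
    if not data_files:
        raise ValueError("No data files were found.")

    # os.path.splitext returns ("name", ".ext"); [1:] strips the leading dot before the lookup.
    data_path: Optional[str] = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    # Compare on the mapped type, not the raw extension, so e.g. .json and .jsonl can coexist.
    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path


if __name__ == "__main__":
    print(detect_data_path(["train.json", "eval.jsonl"]))  # -> "json" (both map to the same loader type)
    # detect_data_path(["train.json", "extra.csv"]) would raise "File types should be identical."

Because the comparison is done on the mapped loader type rather than the raw extension, a directory that mixes .json and .jsonl files still passes, while genuinely different formats (say .csv alongside .json) are rejected. It also means data_path is derived once from data_files[0] instead of being threaded through the directory-walking loop as in the removed code.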