mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 11:42:49 +08:00
tiny fix
Former-commit-id: 0c22da4f1cc710b471f6d511d50ce878521173ca
This commit is contained in:
parent
aba4268607
commit
1b02915d19
@ -69,25 +69,24 @@ def _load_single_dataset(
|
|||||||
if os.path.isdir(local_path): # is directory
|
if os.path.isdir(local_path): # is directory
|
||||||
for file_name in os.listdir(local_path):
|
for file_name in os.listdir(local_path):
|
||||||
data_files.append(os.path.join(local_path, file_name))
|
data_files.append(os.path.join(local_path, file_name))
|
||||||
if data_path is None:
|
|
||||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
|
||||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
|
||||||
raise ValueError("File types should be identical.")
|
|
||||||
elif os.path.isfile(local_path): # is file
|
elif os.path.isfile(local_path): # is file
|
||||||
data_files.append(local_path)
|
data_files.append(local_path)
|
||||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"File {local_path} not found.")
|
raise ValueError(f"File {local_path} not found.")
|
||||||
|
|
||||||
|
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||||
if data_path is None:
|
if data_path is None:
|
||||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||||
|
|
||||||
|
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||||
|
raise ValueError("File types should be identical.")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if dataset_attr.load_from == "ms_hub":
|
||||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset # type: ignore
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
dataset = MsDataset.load(
|
dataset = MsDataset.load(
|
||||||
@ -98,15 +97,15 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
use_streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
|
|
||||||
elif dataset_attr.load_from == "om_hub":
|
elif dataset_attr.load_from == "om_hub":
|
||||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||||
from openmind import OmDataset
|
from openmind import OmDataset # type: ignore
|
||||||
from openmind.utils.hub import OM_DATASETS_CACHE
|
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||||
dataset = OmDataset.load_dataset(
|
dataset = OmDataset.load_dataset(
|
||||||
@ -117,7 +116,7 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.om_hub_token,
|
token=model_args.om_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
@ -128,13 +127,10 @@ def _load_single_dataset(
|
|||||||
split=dataset_attr.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=data_args.streaming,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
|
||||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
|
||||||
|
|
||||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||||
target_num = dataset_attr.num_samples
|
target_num = dataset_attr.num_samples
|
||||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||||
|
@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
|
|||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
if image_input_sizes is None:
|
if image_input_sizes is None:
|
||||||
raise ValueError(
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
|
|
||||||
)
|
|
||||||
|
|
||||||
image_size = image_input_sizes[0][num_image_tokens]
|
image_size = image_input_sizes[0][num_image_tokens]
|
||||||
height, width = image_size
|
height, width = image_size
|
||||||
@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
|
|||||||
message["content"] = content
|
message["content"] = content
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
r"""
|
r"""
|
||||||
Gets chat template and fixes the tokenizer.
|
Gets chat template and fixes the tokenizer.
|
||||||
"""
|
"""
|
||||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
|
||||||
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
|
||||||
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
|
|
||||||
|
|
||||||
if data_args.template is None:
|
if data_args.template is None:
|
||||||
template = TEMPLATES["empty"] # placeholder
|
template = TEMPLATES["empty"] # placeholder
|
||||||
else:
|
else:
|
||||||
@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
if template is None:
|
if template is None:
|
||||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||||
|
|
||||||
|
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||||
|
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||||
|
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user