Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-03 04:02:49 +08:00
commit 92940817e7 (parent ae09c6c214)

    lint

    Former-commit-id: 6a5074e46695378b76d58aac8ad7768b6b034b9c
@@ -14,13 +14,11 @@
 
 import os
 import sys
-import zstandard as zstd
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
-from datasets import DatasetDict, Dataset, load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version
-from typing import Union
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
@@ -54,7 +52,6 @@ def _load_single_dataset(
     Loads a single dataset and aligns it to the standard format.
     """
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
-
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
@@ -122,7 +119,6 @@ def _load_single_dataset(
             streaming=data_args.streaming,
         )
     else:
-        logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...")
         dataset = load_dataset(
             path=data_path,
             name=data_name,
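Note: the removed logger.info call was a leftover debug trace; the info_rank0 call earlier in the function already reports which dataset is being loaded. For reference, a minimal sketch of the load_dataset call shape used in this branch, with an invented local file path (the keyword arguments are the standard Hugging Face datasets API):

from datasets import load_dataset

# Sketch only: "data/alpaca_demo.json" is an invented example path.
dataset = load_dataset(
    path="json",                         # builder name, e.g. resolved via FILEEXT2TYPE
    data_files="data/alpaca_demo.json",  # local file(s) to read
    split="train",                       # split to materialize
    streaming=False,                     # True yields an iterable dataset instead
)
print(dataset[0])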
@@ -243,31 +239,25 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
 
             dataset_module: Dict[str, "Dataset"] = {}
-
-            if isinstance(dataset_dict, DatasetDict):
-                print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}")
-                if "train" in dataset_dict:
-                    dataset_module["train_dataset"] = dataset_dict["train"]
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]
 
-                if "validation" in dataset_dict:
-                    dataset_module["eval_dataset"] = dataset_dict["validation"]
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
 
-                if data_args.streaming:
-                    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data
 
-                logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.")
-                return dataset_module
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
 
-            elif isinstance(dataset_dict, Dataset):
-                logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.")
-                dataset_module["train_dataset"] = dataset_dict
-                return dataset_module
-            else:
-                raise ValueError("Unknown dataset type loaded!")
+            return dataset_module
 
         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")
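Note: the rewritten block must handle both return shapes of load_from_disk, which yields a Dataset when a single split was saved and a DatasetDict when named splits were saved. A minimal standalone sketch of the same dispatch, using invented toy data and an invented /tmp path:

from typing import Dict, Union

from datasets import Dataset, DatasetDict, load_from_disk

# Invented toy data standing in for a previously tokenized dataset.
tokenized = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}),
    "validation": Dataset.from_dict({"input_ids": [[5, 6]]}),
})
tokenized.save_to_disk("/tmp/tokenized_demo")  # invented path

tokenized_data: Union[Dataset, DatasetDict] = load_from_disk("/tmp/tokenized_demo")

dataset_module: Dict[str, Dataset] = {}
if isinstance(tokenized_data, DatasetDict):
    # Named splits map onto the train/eval slots.
    if "train" in tokenized_data:
        dataset_module["train_dataset"] = tokenized_data["train"]
    if "validation" in tokenized_data:
        dataset_module["eval_dataset"] = tokenized_data["validation"]
else:  # a bare Dataset becomes the training split
    dataset_module["train_dataset"] = tokenized_data

# Mirrors the streaming branch: convert map-style datasets to iterable ones.
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
print(dataset_module.keys())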