import os
import hashlib
from typing import List

from datasets import Dataset, concatenate_datasets, load_dataset

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments


logger = get_logger(__name__)


def get_dataset(
    model_args: ModelArguments,
    data_args: DataArguments
) -> Dataset:
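    """Load every dataset listed in `data_args.dataset_list`, align their columns, and merge them into a single `Dataset`."""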

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets
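
    # Load, optionally truncate, and column-align each dataset before merging.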
    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []
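
            # A local dataset may be a directory of files that must all map to the same loader type, or a single file.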
            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError
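
        # At this point data_path is a Hub dataset id, a local loading script path,
        # or a builder name ("csv" / "json" / "text") paired with data_files.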
        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))
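
        # Rename or create columns so every dataset exposes prompt / query / response / history,
        # then attach the per-dataset source prefix as a "prefix" column.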
        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
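

# Usage sketch (illustrative, not part of the original module): the two argument
# dataclasses are typically built with transformers' HfArgumentParser before this
# loader runs; the exact entry point is assumed here.
#
#   from transformers import HfArgumentParser
#
#   parser = HfArgumentParser((ModelArguments, DataArguments))
#   model_args, data_args = parser.parse_args_into_dataclasses()
#   dataset = get_dataset(model_args, data_args)  # merged Dataset with prompt/query/response/history/prefix columns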