mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00

commit 92940817e7 (parent: ae09c6c214)
Former-commit-id: 6a5074e46695378b76d58aac8ad7768b6b034b9c

    lint
@@ -14,13 +14,11 @@
 import os
 import sys
-import zstandard as zstd
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union

 import numpy as np
-from datasets import DatasetDict, Dataset, load_dataset, load_from_disk
+from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version
-from typing import Union

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
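Note: the dropped `from typing import Union` was redundant, since `Union` already arrives through the consolidated typing import at the top of the hunk. A duplicate import is harmless at runtime, but linters flag it (flake8 reports it as F811, redefinition of an unused name), which fits the commit title.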
@@ -54,7 +52,6 @@ def _load_single_dataset(
     Loads a single dataset and aligns it to the standard format.
     """
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
-
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
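For context, the `info_rank0`/`warning_rank0` helpers used throughout this file, as their names suggest, log only from the main process, so multi-GPU runs do not repeat every message once per worker. A minimal sketch of that pattern, assuming the RANK environment variable convention set by torchrun; the project's actual helper lives in `..extras.logging` and may differ in detail:

    import logging
    import os

    logger = logging.getLogger(__name__)

    def info_rank0(msg: str) -> None:
        # torchrun exports RANK for every worker; only rank 0 should speak.
        if int(os.environ.get("RANK", "0")) == 0:
            logger.info(msg)

    info_rank0("Loading dataset ...")  # printed once, not once per GPU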
@@ -122,7 +119,6 @@ def _load_single_dataset(
             streaming=data_args.streaming,
         )
     else:
-        logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...")
         dataset = load_dataset(
             path=data_path,
             name=data_name,
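The surviving call goes through the standard `datasets.load_dataset` entry point. A minimal, hypothetical usage sketch with the same keyword arguments the loader passes (the file path here is made up):

    from datasets import load_dataset

    # Load a local JSON file; "json" is a built-in loader, and
    # data_files/split/streaming are standard load_dataset parameters.
    dataset = load_dataset(
        path="json",
        name=None,
        data_files={"train": "data/example.json"},  # hypothetical path
        split="train",
        streaming=False,
    )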
@@ -243,30 +239,24 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
-
-            if isinstance(dataset_dict, DatasetDict):
-                print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}")
-                if "train" in dataset_dict:
-                    dataset_module["train_dataset"] = dataset_dict["train"]
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]

-                if "validation" in dataset_dict:
-                    dataset_module["eval_dataset"] = dataset_dict["validation"]
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]

-                if data_args.streaming:
-                    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data

-                logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.")
-                return dataset_module
-
-            elif isinstance(dataset_dict, Dataset):
-                logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.")
-                dataset_module["train_dataset"] = dataset_dict
-                return dataset_module
-            else:
-                raise ValueError("Unknown dataset type loaded!")
+            if data_args.streaming:
+                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
+
+            return dataset_module

         if data_args.streaming:
             raise ValueError("Turn off `streaming` when saving dataset to disk.")
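The rewritten branch exists because `load_from_disk` returns whatever `save_to_disk` wrote: either a bare `Dataset` or a `DatasetDict` with named splits, and the module dict must be normalized either way. A small self-contained sketch of the round trip and the streaming conversion (paths and column names are made up):

    from datasets import Dataset, DatasetDict, load_from_disk

    # Round trip: save_to_disk preserves the container type.
    DatasetDict({"train": Dataset.from_dict({"ids": [1, 2]})}).save_to_disk("/tmp/tok")
    tokenized_data = load_from_disk("/tmp/tok")  # hypothetical path

    dataset_module = {}
    if isinstance(tokenized_data, DatasetDict):  # named splits
        if "train" in tokenized_data:
            dataset_module["train_dataset"] = tokenized_data["train"]
        if "validation" in tokenized_data:
            dataset_module["eval_dataset"] = tokenized_data["validation"]
    else:  # a bare Dataset counts as training data
        dataset_module["train_dataset"] = tokenized_data

    # streaming=True asks for lazy iteration, so map-style datasets are converted.
    dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}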