From 92940817e7f862ee485d81207e4e3e4e05733dde Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 4 Dec 2024 22:08:27 +0800 Subject: [PATCH] lint Former-commit-id: 6a5074e46695378b76d58aac8ad7768b6b034b9c --- src/llamafactory/data/loader.py | 36 ++++++++++++--------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index abdc716d..4924f98e 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -14,13 +14,11 @@ import os import sys -import zstandard as zstd from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union import numpy as np -from datasets import DatasetDict, Dataset, load_dataset, load_from_disk +from datasets import DatasetDict, load_dataset, load_from_disk from transformers.utils.versions import require_version -from typing import Union from ..extras import logging from ..extras.constants import FILEEXT2TYPE @@ -54,7 +52,6 @@ def _load_single_dataset( Loads a single dataset and aligns it to the standard format. """ logger.info_rank0(f"Loading dataset {dataset_attr}...") - data_path, data_name, data_dir, data_files = None, None, None, None if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: data_path = dataset_attr.dataset_name @@ -122,7 +119,6 @@ def _load_single_dataset( streaming=data_args.streaming, ) else: - logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...") dataset = load_dataset( path=data_path, name=data_name, @@ -243,30 +239,24 @@ def get_dataset( if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): logger.warning_rank0("Loading dataset from disk will ignore other data arguments.") - dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path) + tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path) logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.") + dataset_module: Dict[str, "Dataset"] = {} + if isinstance(tokenized_data, DatasetDict): + if "train" in tokenized_data: + dataset_module["train_dataset"] = tokenized_data["train"] - if isinstance(dataset_dict, DatasetDict): - print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}") - if "train" in dataset_dict: - dataset_module["train_dataset"] = dataset_dict["train"] + if "validation" in tokenized_data: + dataset_module["eval_dataset"] = tokenized_data["validation"] - if "validation" in dataset_dict: - dataset_module["eval_dataset"] = dataset_dict["validation"] + else: # Dataset + dataset_module["train_dataset"] = tokenized_data - if data_args.streaming: - dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + if data_args.streaming: + dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} - logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.") - return dataset_module - - elif isinstance(dataset_dict, Dataset): - logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.") - dataset_module["train_dataset"] = dataset_dict - return dataset_module - else: - raise ValueError("Unknown dataset type loaded!") + return dataset_module if data_args.streaming: raise ValueError("Turn off `streaming` when saving dataset to disk.")