From ae09c6c214d957e82f010547ccc3803e60a9b91d Mon Sep 17 00:00:00 2001 From: wangdepeng Date: Wed, 27 Nov 2024 16:44:42 +0800 Subject: [PATCH] fix:tokenized_path not None and load_from_disk return Dataset Trigger stuck Former-commit-id: 4424d4de8aca0e4d3b92672584978f3cc3fc33da --- src/llamafactory/data/loader.py | 34 +++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 1c3a7747..abdc716d 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -14,11 +14,13 @@ import os import sys +import zstandard as zstd from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union import numpy as np -from datasets import DatasetDict, load_dataset, load_from_disk +from datasets import DatasetDict, Dataset, load_dataset, load_from_disk from transformers.utils.versions import require_version +from typing import Union from ..extras import logging from ..extras.constants import FILEEXT2TYPE @@ -52,6 +54,7 @@ def _load_single_dataset( Loads a single dataset and aligns it to the standard format. """ logger.info_rank0(f"Loading dataset {dataset_attr}...") + data_path, data_name, data_dir, data_files = None, None, None, None if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: data_path = dataset_attr.dataset_name @@ -119,6 +122,7 @@ def _load_single_dataset( streaming=data_args.streaming, ) else: + logger.info(f"Loading dataset {data_path}, {data_name}, {data_dir}, {data_files},{dataset_attr.split},{model_args.cache_dir},{model_args.hf_hub_token},{data_args.streaming}...") dataset = load_dataset( path=data_path, name=data_name, @@ -239,20 +243,30 @@ def get_dataset( if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): logger.warning_rank0("Loading dataset from disk will ignore other data arguments.") - dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) + dataset_dict: Union[DatasetDict, Dataset] = load_from_disk(data_args.tokenized_path) logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.") - dataset_module: Dict[str, "Dataset"] = {} - if "train" in dataset_dict: - dataset_module["train_dataset"] = dataset_dict["train"] - if "validation" in dataset_dict: - dataset_module["eval_dataset"] = dataset_dict["validation"] + if isinstance(dataset_dict, DatasetDict): + print(f"Loaded DatasetDict with keys: {dataset_dict.keys()}") + if "train" in dataset_dict: + dataset_module["train_dataset"] = dataset_dict["train"] - if data_args.streaming: - dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + if "validation" in dataset_dict: + dataset_module["eval_dataset"] = dataset_dict["validation"] - return dataset_module + if data_args.streaming: + dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()} + + logger.info_rank0(f"Finished tokenized dataset load from {data_args.tokenized_path}.") + return dataset_module + + elif isinstance(dataset_dict, Dataset): + logger.info_rank0(f"Loaded single Dataset with {len(dataset_dict)} samples.") + dataset_module["train_dataset"] = dataset_dict + return dataset_module + else: + raise ValueError("Unknown dataset type loaded!") if data_args.streaming: raise ValueError("Turn off `streaming` when saving dataset to disk.")