diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index c443b9d9..2dccfc5d 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -15,16 +15,14 @@ import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Sequence + +from transformers.utils import cached_file from ..extras.constants import DATA_CONFIG from ..extras.misc import use_modelscope -if TYPE_CHECKING: - from ..hparams import DataArguments - - @dataclass class DatasetAttr: r""" @@ -72,31 +70,33 @@ class DatasetAttr: setattr(self, key, obj.get(key, default)) -def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]: - if dataset is not None: - dataset_names = [ds.strip() for ds in dataset.split(",")] - else: +def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]: + r""" + Gets the attributes of the datasets. + """ + if dataset_names is None: dataset_names = [] - if data_args.dataset_dir == "ONLINE": + if dataset_dir == "ONLINE": dataset_info = None else: + if dataset_dir.startswith("REMOTE:"): + config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + else: + config_path = os.path.join(dataset_dir, DATA_CONFIG) + try: - with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: + with open(config_path, "r") as f: dataset_info = json.load(f) except Exception as err: if len(dataset_names) != 0: - raise ValueError( - "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) - ) + raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) + dataset_info = None - if data_args.interleave_probs is not None: - data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] - - dataset_list: List[DatasetAttr] = [] + dataset_list: List["DatasetAttr"] = [] for name in dataset_names: - if dataset_info is None: + if dataset_info is None: # dataset_dir is ONLINE load_from = "ms_hub" if use_modelscope() else "hf_hub" dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_list.append(dataset_attr)