diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 3c692784..b1dc6e6c 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -15,7 +15,7 @@ import json import os from dataclasses import dataclass -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union from huggingface_hub import hf_hub_download @@ -90,12 +90,14 @@ class DatasetAttr: self.set_attr(tag, attr["tags"]) -def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]: +def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]: r"""Get the attributes of the datasets.""" if dataset_names is None: dataset_names = [] - if dataset_dir == "ONLINE": + if isinstance(dataset_dir, dict): + dataset_info = dataset_dir + elif dataset_dir == "ONLINE": dataset_info = None else: if dataset_dir.startswith("REMOTE:"): diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index e6844733..4c8351f8 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -16,7 +16,7 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union @dataclass @@ -35,7 +35,7 @@ class DataArguments: default=None, metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."}, ) - dataset_dir: str = field( + dataset_dir: Union[str, dict] = field( default="data", metadata={"help": "Path to the folder containing the datasets."}, )