mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-24 06:42:52 +08:00
Update parser.py
Former-commit-id: 84e4047f8a1f78256be65f3f7bddce358ed9e882
This commit is contained in:
parent
5633c0ab1e
commit
eed7cbb453
@ -15,16 +15,14 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..hparams import DataArguments
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetAttr:
|
class DatasetAttr:
|
||||||
r"""
|
r"""
|
||||||
@ -72,31 +70,33 @@ class DatasetAttr:
|
|||||||
setattr(self, key, obj.get(key, default))
|
setattr(self, key, obj.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]:
|
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||||
if dataset is not None:
|
r"""
|
||||||
dataset_names = [ds.strip() for ds in dataset.split(",")]
|
Gets the attributes of the datasets.
|
||||||
else:
|
"""
|
||||||
|
if dataset_names is None:
|
||||||
dataset_names = []
|
dataset_names = []
|
||||||
|
|
||||||
if data_args.dataset_dir == "ONLINE":
|
if dataset_dir == "ONLINE":
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
else:
|
else:
|
||||||
|
if dataset_dir.startswith("REMOTE:"):
|
||||||
|
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||||
|
else:
|
||||||
|
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
with open(config_path, "r") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if len(dataset_names) != 0:
|
if len(dataset_names) != 0:
|
||||||
raise ValueError(
|
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
||||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
|
||||||
)
|
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
if data_args.interleave_probs is not None:
|
dataset_list: List["DatasetAttr"] = []
|
||||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
|
||||||
|
|
||||||
dataset_list: List[DatasetAttr] = []
|
|
||||||
for name in dataset_names:
|
for name in dataset_names:
|
||||||
if dataset_info is None:
|
if dataset_info is None: # dataset_dir is ONLINE
|
||||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||||
dataset_list.append(dataset_attr)
|
dataset_list.append(dataset_attr)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user