Update parser.py

Former-commit-id: 84e4047f8a1f78256be65f3f7bddce358ed9e882
This commit is contained in:
hoshi-hiyouga 2024-07-15 00:55:21 +08:00 committed by GitHub
parent 5633c0ab1e
commit eed7cbb453

View File

@@ -15,16 +15,14 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional, Sequence
from transformers.utils import cached_file
from ..extras.constants import DATA_CONFIG from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope
if TYPE_CHECKING:
from ..hparams import DataArguments
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
r""" r"""
@@ -72,31 +70,33 @@ class DatasetAttr:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]: def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
if dataset is not None: r"""
dataset_names = [ds.strip() for ds in dataset.split(",")] Gets the attributes of the datasets.
else: """
if dataset_names is None:
dataset_names = [] dataset_names = []
if data_args.dataset_dir == "ONLINE": if dataset_dir == "ONLINE":
dataset_info = None dataset_info = None
else: else:
if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else:
config_path = os.path.join(dataset_dir, DATA_CONFIG)
try: try:
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: with open(config_path, "r") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
except Exception as err: except Exception as err:
if len(dataset_names) != 0: if len(dataset_names) != 0:
raise ValueError( raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None dataset_info = None
if data_args.interleave_probs is not None: dataset_list: List["DatasetAttr"] = []
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names: for name in dataset_names:
if dataset_info is None: if dataset_info is None: # dataset_dir is ONLINE
load_from = "ms_hub" if use_modelscope() else "hf_hub" load_from = "ms_hub" if use_modelscope() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name) dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr) dataset_list.append(dataset_attr)