modify style

This commit is contained in:
BUAADreamer
2024-04-25 21:15:16 +08:00
parent 43d7ad5ecc
commit 1dcabafe72
16 changed files with 374 additions and 502 deletions

View File

@@ -50,9 +50,7 @@ class DatasetAttr:
# String representation of the attribute: simply its `dataset_name`.
def __repr__(self) -> str:
return self.dataset_name
def set_attr(
self, key: str, obj: Dict[str, Any], default: Optional[Any] = None
) -> None:
# Copy one entry from a mapping onto this instance as an attribute.
#   key:     name looked up in `obj`, also used as the attribute name on `self`.
#   obj:     mapping to read from.
#   default: value stored when `key` is absent from `obj` (None if not given).
# The attribute is always assigned, even when the key is missing from `obj`.
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
setattr(self, key, obj.get(key, default))
@@ -71,16 +69,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
except Exception as err:
if len(dataset_names) != 0:
raise ValueError(
"Cannot open {} due to {}.".format(
os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)
)
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
)
dataset_info = None
if data_args.interleave_probs is not None:
data_args.interleave_probs = [
float(prob.strip()) for prob in data_args.interleave_probs.split(",")
]
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
dataset_list: List[DatasetAttr] = []
for name in dataset_names:
@@ -98,21 +92,13 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
if has_hf_url or has_ms_url:
if (use_modelscope() and has_ms_url) or (not has_hf_url):
dataset_attr = DatasetAttr(
"ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]
)
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
else:
dataset_attr = DatasetAttr(
"hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]
)
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"script", dataset_name=dataset_info[name]["script_url"]
)
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file", dataset_name=dataset_info[name]["file_name"]
)
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("file_sha1", dataset_info[name])
dataset_attr.set_attr("subset", dataset_info[name])