Merge pull request #3291 from codemayq/main

support for previewing custom dataset in directory format

Former-commit-id: aa3206ec26236e8e176a3aa14229e0b0b31eb585
This commit is contained in:
hoshi-hiyouga 2024-04-16 18:12:09 +08:00 committed by GitHub
commit 7daaae2005

View File

@@ -29,28 +29,41 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
except Exception: except Exception:
return gr.Button(interactive=False) return gr.Button(interactive=False)
if ( if len(dataset) == 0 or "file_name" not in dataset_info[dataset[0]]:
len(dataset) > 0 return gr.Button(interactive=False)
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) local_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
): if (os.path.isfile(local_path)
or (os.path.isdir(local_path) and len(os.listdir(local_path)) != 0)):
return gr.Button(interactive=True) return gr.Button(interactive=True)
else: else:
return gr.Button(interactive=False) return gr.Button(interactive=False)
def load_single_data(data_file_path: str):
    """Load the records stored in a single dataset file.

    The parser is chosen from the file extension:

    * ``.json``  -- the whole file is parsed as one JSON document.
    * ``.jsonl`` -- every non-blank line is parsed as its own JSON document.
    * anything else -- the raw text lines are returned, line endings included.

    Args:
        data_file_path: Path of the file to read; it is opened as UTF-8 text.

    Returns:
        The parsed JSON document for ``.json`` files, otherwise a list with
        one entry per (non-blank, for ``.jsonl``) line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a ``.json``/``.jsonl`` payload is malformed.
    """
    with open(data_file_path, "r", encoding="utf-8") as f:
        if data_file_path.endswith(".json"):
            return json.load(f)
        if data_file_path.endswith(".jsonl"):
            # A blank line (e.g. a doubled trailing newline) is not a JSON
            # document; skip it instead of failing the whole preview.
            return [json.loads(line) for line in f if line.strip()]
        return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]: def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
data_file: str = dataset_info[dataset[0]]["file_name"] data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: local_path = os.path.join(dataset_dir, data_file)
if data_file.endswith(".json"): if os.path.isdir(local_path):
data = json.load(f) data = []
elif data_file.endswith(".jsonl"): for file_name in os.listdir(local_path):
data = [json.loads(line) for line in f] data.extend(load_single_data(os.path.join(local_path, file_name)))
else: else:
data = [line for line in f] # noqa: C416 data = load_single_data(local_path)
return len(data), data[PAGE_SIZE * page_index: PAGE_SIZE * (page_index + 1)], gr.Column(visible=True) return len(data), data[PAGE_SIZE * page_index: PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)