mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-06 05:32:50 +08:00
121 lines
4.3 KiB
Python
121 lines
4.3 KiB
Python
# Copyright 2024 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
|
|
from ...extras.constants import DATA_CONFIG
|
|
from ...extras.packages import is_gradio_available
|
|
|
|
|
|
if is_gradio_available():
|
|
import gradio as gr
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from gradio.components import Component
|
|
|
|
|
|
# Number of samples shown per page in the dataset preview box; used for slicing
# in get_preview and for the "is there a next page" check in next_page.
PAGE_SIZE = 2
|
|
|
|
|
|
def prev_page(page_index: int) -> int:
    r"""
    Returns the index of the previous preview page.

    The index is left unchanged when it is already at (or below) zero, so the
    first page acts as a floor.
    """
    if page_index <= 0:
        return page_index
    return page_index - 1
|
|
|
|
|
|
def next_page(page_index: int, total_num: int) -> int:
    r"""
    Returns the index of the next preview page.

    Advances only when the next page would start before ``total_num`` (i.e. it
    holds at least one sample); otherwise the current index is returned.
    """
    next_page_start = (page_index + 1) * PAGE_SIZE
    if next_page_start < total_num:
        return page_index + 1
    return page_index
|
|
|
|
|
|
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
    r"""
    Decides whether the dataset preview button should be clickable.

    Only the first selected dataset is inspected. The button is enabled when its
    ``file_name`` entry in the dataset config (``DATA_CONFIG`` inside
    ``dataset_dir``) points to an existing file or to a non-empty directory.

    Args:
        dataset_dir: Directory containing the dataset config and local data files.
        dataset: Names of the selected datasets; may be empty or ``None``.

    Returns:
        A ``gr.Button`` update whose ``interactive`` flag reflects previewability.
    """
    # Nothing selected (empty list or None): nothing to preview.
    if not dataset:
        return gr.Button(interactive=False)

    try:
        with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
            dataset_info = json.load(f)
    except Exception:
        # Missing/unreadable config is treated as "not previewable" instead of crashing the UI callback.
        return gr.Button(interactive=False)

    # The selected name may not be defined in the config, or the config may not be a
    # mapping; in either case there is no local file to point at.
    entry = dataset_info.get(dataset[0]) if isinstance(dataset_info, dict) else None
    if not isinstance(entry, dict) or "file_name" not in entry:
        return gr.Button(interactive=False)

    data_path = os.path.join(dataset_dir, entry["file_name"])
    previewable = os.path.isfile(data_path) or (os.path.isdir(data_path) and bool(os.listdir(data_path)))
    return gr.Button(interactive=previewable)
|
|
|
|
|
|
def _load_data_file(file_path: str) -> List[Any]:
|
|
with open(file_path, encoding="utf-8") as f:
|
|
if file_path.endswith(".json"):
|
|
return json.load(f)
|
|
elif file_path.endswith(".jsonl"):
|
|
return [json.loads(line) for line in f]
|
|
else:
|
|
return list(f)
|
|
|
|
|
|
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
    r"""
    Loads one page of samples from the first selected dataset.

    Args:
        dataset_dir: Directory containing the dataset config and local data files.
        dataset: Names of the selected datasets; only ``dataset[0]`` is previewed.
        page_index: Zero-based page number; each page holds ``PAGE_SIZE`` samples.

    Returns:
        A tuple of (total sample count, samples on the requested page,
        a ``gr.Column`` update that makes the preview box visible).

    Errors from reading the config or the data files (e.g. missing keys or files)
    are not caught here and propagate to the caller.
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
        dataset_info = json.load(f)

    data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])
    if os.path.isfile(data_path):
        data = _load_data_file(data_path)
    else:
        data = []
        # os.listdir returns entries in arbitrary order; sort so the concatenated data
        # (and therefore pagination across repeated calls) is deterministic.
        for file_name in sorted(os.listdir(data_path)):
            file_path = os.path.join(data_path, file_name)
            # Sub-directories cannot be opened as files; skip them instead of crashing.
            if os.path.isfile(file_path):
                data.extend(_load_data_file(file_path))

    page_start = PAGE_SIZE * page_index
    return len(data), data[page_start : page_start + PAGE_SIZE], gr.Column(visible=True)
|
|
|
|
|
|
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
    r"""
    Builds the dataset preview modal and wires up its event handlers.

    Args:
        dataset_dir: Textbox component holding the dataset directory.
        dataset: Dropdown component holding the selected dataset names.

    Returns:
        A mapping from element name to the Gradio components created here.
    """
    data_preview_btn = gr.Button(interactive=False, scale=1)
    with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
        with gr.Row():
            preview_count = gr.Number(value=0, interactive=False, precision=0)
            page_index = gr.Number(value=0, interactive=False, precision=0)

        with gr.Row():
            prev_btn = gr.Button()
            next_btn = gr.Button()
            close_btn = gr.Button()

        with gr.Row():
            preview_samples = gr.JSON()

    # Every "render a page" step reads the same inputs and refreshes the same outputs.
    page_inputs = [dataset_dir, dataset, page_index]
    page_outputs = [preview_count, preview_samples, preview_box]

    # On selection change: recompute whether preview is possible, then reset to the first page.
    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
        lambda: 0, outputs=[page_index], queue=False
    )
    data_preview_btn.click(get_preview, page_inputs, page_outputs, queue=False)
    # Paging: update the index first, then re-render the page it points to.
    prev_btn.click(prev_page, [page_index], [page_index], queue=False).then(
        get_preview, page_inputs, page_outputs, queue=False
    )
    next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
        get_preview, page_inputs, page_outputs, queue=False
    )
    close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)

    return {
        "data_preview_btn": data_preview_btn,
        "preview_count": preview_count,
        "page_index": page_index,
        "prev_btn": prev_btn,
        "next_btn": next_btn,
        "close_btn": close_btn,
        "preview_samples": preview_samples,
    }
|