mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
support pagination in webui preview
This commit is contained in:
@@ -3,13 +3,11 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
@@ -33,37 +31,6 @@ def get_time() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(
|
||||
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
|
||||
) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
if data_file.endswith(".json"):
|
||||
data = json.load(f)
|
||||
elif data_file.endswith(".jsonl"):
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f]
|
||||
return len(data), data[start:end], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="None", interactive=False)
|
||||
@@ -116,42 +83,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
if not model_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if not checkpoints:
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
if not export_dir:
|
||||
yield ALERTS["err_no_export_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
export_dir=export_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
Reference in New Issue
Block a user