modify code structure

This commit is contained in:
hiyouga
2023-08-02 23:17:36 +08:00
parent 1d8a1878ea
commit 08f180e788
25 changed files with 188 additions and 145 deletions

View File

@@ -3,20 +3,23 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
def format_info(log: str, tracker: dict) -> str:
def format_info(log: str, callback: "LogCallback") -> str:
info = log
if "current_steps" in tracker:
if callback.max_steps:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
)
return info
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
return fig
def export_model(
def save_model(
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
) -> Generator[str, None, None]:
if not model_name:
@@ -114,12 +117,10 @@ def export_model(
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type
finetuning_type=finetuning_type,
output_dir=save_dir
)
yield ALERTS["info_exporting"][lang]
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
tokenizer.save_pretrained(save_dir)
export_model(args, max_shard_size="{}GB".format(max_shard_size))
yield ALERTS["info_exported"][lang]