mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
modify code structure
This commit is contained in:
@@ -3,20 +3,23 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
|
||||
def format_info(log: str, tracker: dict) -> str:
|
||||
|
||||
def format_info(log: str, callback: "LogCallback") -> str:
|
||||
info = log
|
||||
if "current_steps" in tracker:
|
||||
if callback.max_steps:
|
||||
info += "Running **{:d}/{:d}**: {} < {}\n".format(
|
||||
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return info
|
||||
|
||||
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
return fig
|
||||
|
||||
|
||||
def export_model(
|
||||
def save_model(
|
||||
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
@@ -114,12 +117,10 @@ def export_model(
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type
|
||||
finetuning_type=finetuning_type,
|
||||
output_dir=save_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
Reference in New Issue
Block a user