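"""Miscellaneous helpers for the llmtuner web UI: progress-bar updates, shell
command preview, evaluation-result rendering, and loss-curve plotting."""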
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional

import gradio as gr

from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir

if TYPE_CHECKING:
    from llmtuner.extras.callbacks import LogCallback

if is_matplotlib_available():
    import matplotlib.figure
    import matplotlib.pyplot as plt

def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
    """Refresh the training progress bar from the trainer's log callback."""
    if not callback.max_steps:
        return gr.update(visible=False)

    # max_steps is guaranteed non-zero here, so the division is safe
    percentage = round(100 * callback.cur_steps / callback.max_steps, 0)
    label = "Running {:d}/{:d}: {} < {}".format(
        callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
    )
    return gr.update(label=label, value=percentage, visible=True)

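# Illustrative wiring (component names here are hypothetical, not from this
# module): poll the callback periodically from a Blocks app, e.g.
#   demo.load(lambda: update_process_bar(callback), outputs=[progress_bar], every=1)
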
def get_time() -> str:
    """Return the current time as a filesystem-friendly string."""
    return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    """Enable the quantization option only for LoRA fine-tuning."""
    if finetuning_type != "lora":
        return gr.update(value="None", interactive=False)
    else:
        return gr.update(interactive=True)

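# Illustrative wiring (component names are hypothetical): lock the quantization
# dropdown whenever a non-LoRA finetuning type is selected, e.g.
#   finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit])
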
def gen_cmd(args: Dict[str, Any]) -> str:
    """Render the training arguments as a copyable shell command in markdown."""
    args.pop("disable_tqdm", None)
    args["plot_loss"] = args.get("do_train", None)  # plot the loss curve when training
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") or "0"
    cmd_lines = [f"CUDA_VISIBLE_DEVICES={cuda_visible_devices} python src/train_bash.py "]
    for k, v in args.items():
        if v is not None and v != "":
            cmd_lines.append(" --{} {} ".format(k, str(v)))
    cmd_text = "\\\n".join(cmd_lines)  # backslash-continued lines for readability
    cmd_text = "```bash\n{}\n```".format(cmd_text)
    return cmd_text

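# Example (illustrative argument values): gen_cmd({"stage": "sft", "do_train": True})
# returns a markdown code block roughly like:
#   ```bash
#   CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
#    --stage sft \
#    --do_train True \
#    --plot_loss True
#   ```
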
def get_eval_results(path: os.PathLike) -> str:
    """Load a JSON result file and format it as a markdown code block."""
    with open(path, "r", encoding="utf-8") as f:
        result = json.dumps(json.load(f), indent=4)
    return "```json\n{}\n```\n".format(result)

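# Illustrative usage (the path below is hypothetical): display the metrics in a
# markdown component, e.g.
#   gr.Markdown(get_eval_results("saves/llama-7b/lora/eval/all_results.json"))
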
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> Optional["matplotlib.figure.Figure"]:
    """Plot the raw and smoothed training loss recorded in trainer_log.jsonl."""
    if not base_model:
        return None

    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return None

    plt.close("all")
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:  # one JSON record per line
            log_info = json.loads(line)
            if log_info.get("loss", None):
                steps.append(log_info["current_steps"])
                losses.append(log_info["loss"])

    if len(losses) == 0:
        return None

    ax.plot(steps, losses, alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig
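
# Minimal usage sketch (illustrative, not part of the original module): wire
# gen_plot into a small Blocks app. All component names are hypothetical.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        base_model = gr.Textbox(label="Base model")
        finetuning_type = gr.Dropdown(choices=["lora", "freeze", "full"], value="lora", label="Finetuning type")
        output_dir = gr.Textbox(label="Output dir")
        loss_plot = gr.Plot()
        gr.Button("Plot loss").click(
            gen_plot, inputs=[base_model, finetuning_type, output_dir], outputs=[loss_plot]
        )
    demo.launch()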