mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
update webui and add CLIs
This commit is contained in:
@@ -1,4 +0,0 @@
|
||||
from .interface import create_ui, create_web_demo
|
||||
|
||||
|
||||
__all__ = ["create_ui", "create_web_demo"]
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from yaml import safe_dump, safe_load
|
||||
|
||||
from ..extras.constants import (
|
||||
DATA_CONFIG,
|
||||
@@ -29,7 +30,7 @@ DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_CONFIG_DIR = "config"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
USER_CONFIG = "user_config.yaml"
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
@@ -47,7 +48,7 @@ def get_save_path(config_path: str) -> os.PathLike:
|
||||
def load_config() -> Dict[str, Any]:
|
||||
try:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
@@ -60,13 +61,13 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(user_config, f)
|
||||
|
||||
|
||||
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
with open(get_save_path(config_path), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -74,7 +75,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
with open(get_save_path(config_path), "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=2, ensure_ascii=False)
|
||||
safe_dump(config_dict, f)
|
||||
|
||||
return str(get_save_path(config_path))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from ...extras.misc import torch_gc
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ...train import export_model
|
||||
from ...train.tuner import export_model
|
||||
from ..common import get_save_dir
|
||||
from ..locales import ALERTS
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
progress_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
@@ -263,14 +263,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
output_dir=output_dir,
|
||||
config_path=config_path,
|
||||
resume_btn=resume_btn,
|
||||
process_bar=process_bar,
|
||||
progress_bar=progress_bar,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer,
|
||||
)
|
||||
)
|
||||
|
||||
input_elems.update({output_dir, config_path})
|
||||
output_elems = [output_box, process_bar, loss_viewer]
|
||||
output_elems = [output_box, progress_bar, loss_viewer]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
|
||||
|
||||
@@ -41,7 +41,7 @@ class Engine:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
|
||||
init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
|
||||
@@ -51,7 +51,7 @@ class Engine:
|
||||
|
||||
yield self._update_component(init_dict)
|
||||
|
||||
if self.runner.alive and not self.demo_mode and not self.pure_chat:
|
||||
if self.runner.running and not self.demo_mode and not self.pure_chat:
|
||||
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._update_component({"train.resume_btn": {"value": True}})
|
||||
|
||||
@@ -68,5 +68,9 @@ def create_web_demo() -> gr.Blocks:
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def run_web_ui():
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
def run_web_demo():
|
||||
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
import transformers
|
||||
import psutil
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_cuda_available
|
||||
|
||||
from ..extras.callbacks import LogCallback
|
||||
from ..extras.constants import TRAINING_STAGES
|
||||
from ..extras.logging import LoggerHandler
|
||||
from ..extras.misc import get_device_count, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from ..train import run_exp
|
||||
from .common import get_module, get_save_dir, load_args, load_config, save_args
|
||||
from .locales import ALERTS
|
||||
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
|
||||
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
|
||||
|
||||
|
||||
if is_gradio_available():
|
||||
@@ -34,24 +31,18 @@ class Runner:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.trainer: Optional["Popen"] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
if self.trainer is not None:
|
||||
for children in psutil.Process(self.trainer.pid).children(): # abort the child process
|
||||
os.kill(children.pid, signal.SIGABRT)
|
||||
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
@@ -85,13 +76,11 @@ class Runner:
|
||||
if not from_preview and not is_torch_cuda_available():
|
||||
gr.Warning(ALERTS["warn_no_cuda"][lang])
|
||||
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
|
||||
self.thread = None
|
||||
self.trainer = None
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
self.running_data = None
|
||||
@@ -270,11 +259,12 @@ class Runner:
|
||||
gr.Warning(error)
|
||||
yield {output_box: error}
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
env = deepcopy(os.environ)
|
||||
env["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
env["LLAMABOARD_ENABLED"] = "1"
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data):
|
||||
@@ -291,9 +281,6 @@ class Runner:
|
||||
|
||||
def monitor(self):
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
lang = get("top.lang")
|
||||
model_name = get("top.model_name")
|
||||
finetuning_type = get("top.finetuning_type")
|
||||
@@ -301,28 +288,31 @@ class Runner:
|
||||
output_path = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
|
||||
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
|
||||
progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
|
||||
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
|
||||
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
while self.trainer is not None:
|
||||
if self.aborted:
|
||||
yield {
|
||||
output_box: ALERTS["info_aborting"][lang],
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
else:
|
||||
running_log, running_progress, running_loss = get_trainer_info(output_path)
|
||||
return_dict = {
|
||||
output_box: self.logger_handler.log,
|
||||
process_bar: update_process_bar(self.trainer_callback),
|
||||
output_box: running_log,
|
||||
progress_bar: running_progress,
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
if self.do_train and running_loss is not None:
|
||||
return_dict[loss_viewer] = running_loss
|
||||
|
||||
yield return_dict
|
||||
|
||||
time.sleep(2)
|
||||
try:
|
||||
self.trainer.wait(2)
|
||||
self.trainer = None
|
||||
except TimeoutExpired:
|
||||
continue
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
|
||||
@@ -337,16 +327,11 @@ class Runner:
|
||||
|
||||
return_dict = {
|
||||
output_box: self._finalize(lang, finish_info),
|
||||
process_bar: gr.Slider(visible=False),
|
||||
progress_bar: gr.Slider(visible=False),
|
||||
}
|
||||
if self.do_train:
|
||||
plot = gen_plot(output_path)
|
||||
if plot is not None:
|
||||
return_dict[loss_viewer] = plot
|
||||
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data):
|
||||
def save_args(self, data: dict):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from yaml import safe_dump
|
||||
|
||||
from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import smooth
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
@@ -12,30 +15,6 @@ if is_gradio_available():
|
||||
import gradio as gr
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
|
||||
if not callback.max_steps:
|
||||
return gr.Slider(visible=False)
|
||||
|
||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type != "lora":
|
||||
return gr.Dropdown(value="none", interactive=False)
|
||||
@@ -57,14 +36,19 @@ def check_json_schema(text: str, lang: str) -> None:
|
||||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
||||
|
||||
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
no_skip_keys = ["packing"]
|
||||
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
|
||||
for k, v in args.items():
|
||||
if v is not None and v is not False and v != "":
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
for k, v in clean_cmd(args).items():
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
return cmd_text
|
||||
@@ -76,29 +60,49 @@ def get_eval_results(path: os.PathLike) -> str:
|
||||
return "```json\n{}\n```\n".format(result)
|
||||
|
||||
|
||||
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
|
||||
log_file = os.path.join(output_path, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file) or not is_matplotlib_available():
|
||||
return
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
steps, losses = [], []
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
log_info: Dict[str, Any] = json.loads(line)
|
||||
if log_info.get("loss", None):
|
||||
steps.append(log_info["current_steps"])
|
||||
losses.append(log_info["loss"])
|
||||
|
||||
if len(losses) == 0:
|
||||
return
|
||||
def get_trainer_info(output_path: os.PathLike) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
|
||||
running_log = ""
|
||||
running_progress = gr.Slider(visible=False)
|
||||
running_loss = None
|
||||
|
||||
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
||||
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
||||
ax.legend()
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
running_log_path = os.path.join(output_path, RUNNING_LOG)
|
||||
if os.path.isfile(running_log_path):
|
||||
with open(running_log_path, "r", encoding="utf-8") as f:
|
||||
running_log = f.read()
|
||||
|
||||
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
|
||||
if os.path.isfile(trainer_log_path):
|
||||
trainer_log: List[Dict[str, Any]] = []
|
||||
with open(trainer_log_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
trainer_log.append(json.loads(line))
|
||||
|
||||
if len(trainer_log) != 0:
|
||||
latest_log = trainer_log[-1]
|
||||
percentage = latest_log["percentage"]
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
latest_log["current_steps"],
|
||||
latest_log["total_steps"],
|
||||
latest_log["elapsed_time"],
|
||||
latest_log["remaining_time"],
|
||||
)
|
||||
running_progress = gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
if is_matplotlib_available():
|
||||
running_loss = gr.Plot(gen_loss_plot(trainer_log))
|
||||
|
||||
return running_log, running_progress, running_loss
|
||||
|
||||
|
||||
def save_cmd(args: Dict[str, Any]) -> str:
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
|
||||
safe_dump(clean_cmd(args), f)
|
||||
|
||||
return os.path.join(output_dir, TRAINER_CONFIG)
|
||||
|
||||
Reference in New Issue
Block a user