[webui] upgrade webui and fix api (#8460)

Yaowei Zheng 2025-06-25 21:59:58 +08:00 committed by GitHub
parent b10333dafb
commit ed57b7ba2a
9 changed files with 109 additions and 27 deletions

View File

@@ -132,7 +132,7 @@ def _process_request(
                 if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
                     video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
                 elif os.path.isfile(video_url):  # local file
-                    video_stream = open(video_url, "rb")
+                    video_stream = video_url
                 else:  # web uri
                     video_stream = requests.get(video_url, stream=True).raw
@@ -143,7 +143,7 @@ def _process_request(
                 if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
                     audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
                 elif os.path.isfile(audio_url):  # local file
-                    audio_stream = open(audio_url, "rb")
+                    audio_stream = audio_url
                 else:  # web uri
                     audio_stream = requests.get(audio_url, stream=True).raw
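Editor's note: the two hunks above stop eagerly opening local media files; the old `open(..., "rb")` handle was never closed, while the bare path lets the downstream processor open and close the file itself. A minimal sketch of the resulting dispatch (illustrative only: `resolve_media` is a hypothetical helper, and it assumes the consumer accepts paths, byte streams, and raw HTTP streams alike):

```python
import base64
import io
import os
import re

import requests


def resolve_media(url: str):
    """Map a media "URL" to something a processor can consume."""
    if re.match(r"^data:(video|audio)/\w+;base64,(.+)$", url):
        # base64 data URI: decode the payload into an in-memory stream
        return io.BytesIO(base64.b64decode(url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(url):
        # local file: return the path itself, so no file descriptor is leaked here
        return url
    else:
        # web URI: hand back the raw, lazily-read response stream
        return requests.get(url, stream=True).raw
```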

View File

@@ -50,7 +50,7 @@ class LoggerHandler(logging.Handler):
     def _write_log(self, log_entry: str) -> None:
         with open(self.running_log, "a", encoding="utf-8") as f:
-            f.write(log_entry + "\n\n")
+            f.write(log_entry + "\n")

     def emit(self, record) -> None:
         if record.name == "httpx":

View File

@@ -182,8 +182,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +207,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1


 def has_tokenized_data(path: "os.PathLike") -> bool:
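Editor's note: a short usage sketch (not part of the commit) showing how the new `get_current_memory` feeds the GB arithmetic used by the webui footer below; the `(0, -1)` fallback, with `-1` in the total slot, signals that no supported accelerator backend was found:

```python
from llamafactory.extras.misc import get_current_memory

free_bytes, total_bytes = get_current_memory()
if total_bytes > 0:
    # mirrors the footer's arithmetic: bytes to GB, rounded to 2 decimals
    used_gb = round((total_bytes - free_bytes) / (1024**3), 2)
    total_gb = round(total_bytes / (1024**3), 2)
    print(f"device memory: {used_gb} / {total_gb} GB")
else:
    print("no supported accelerator detected")  # total == -1 sentinel
```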

View File

@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",

View File

@@ -0,0 +1,42 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras.misc import get_current_memory
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr
+
+
+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
+def get_device_memory() -> "gr.Slider":
+    free, total = get_current_memory()
+    used = round((total - free) / (1024**3), 2)
+    total = round(total / (1024**3), 2)
+    return gr.Slider(minimum=0, maximum=total, value=used, step=0.01)
+
+
+def create_footer() -> dict[str, "Component"]:
+    with gr.Row():
+        device_memory = gr.Slider(interactive=False)
+        timer = gr.Timer(value=5)
+
+    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
+    return dict(device_memory=device_memory)
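Editor's note: the footer is driven entirely by Gradio's polling primitive: a `gr.Timer` fires every five seconds and its `tick` event pushes a refreshed slider state into a non-interactive component. A self-contained toy version of the same pattern (assuming a Gradio release that ships `gr.Timer`, as this commit itself does):

```python
import random

import gradio as gr


def poll_value() -> "gr.Slider":
    # stand-in for get_device_memory(): return a refreshed slider state
    return gr.Slider(minimum=0, maximum=100, value=random.uniform(0, 100), step=0.01)


with gr.Blocks() as demo:
    meter = gr.Slider(interactive=False, label="Polled value")
    timer = gr.Timer(value=5)  # fire every 5 seconds
    timer.tick(poll_value, outputs=[meter], queue=False)

demo.launch()
```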

View File

@@ -112,7 +112,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log

     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):

View File

@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -63,6 +64,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))

+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)

View File

@@ -2849,6 +2849,28 @@ LOCALES = {
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存（GB）。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ（GB）。",
+        },
+    },
 }

View File

@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional

-from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@ class Runner:
         return ""

-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info

     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )

         if get("eval.predict"):
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"

         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()

     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None

         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]

                 yield return_dict

             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue

-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
+        if return_code == 0:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"

-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
-        yield return_dict
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
+        yield return_dict

     def save_args(self, data):
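Editor's note: the monitor rewrite above replaces `wait(2)` with `communicate(timeout=2)`, which both polls for exit and drains the child's stderr (now piped via `stderr=PIPE, text=True`) so a failure's traceback can be shown in the UI instead of a bare error message. A minimal standalone sketch of the pattern; per the `subprocess` docs, retrying `communicate` after `TimeoutExpired` loses no output:

```python
from subprocess import PIPE, Popen, TimeoutExpired

# toy child process that writes to stderr and exits non-zero
proc = Popen(
    ["python", "-c", "import sys; print('boom', file=sys.stderr); sys.exit(3)"],
    stderr=PIPE,
    text=True,
)

return_code = -1
while return_code == -1:
    try:
        stderr = proc.communicate(timeout=2)[1]  # raises TimeoutExpired while running
        return_code = proc.returncode
    except TimeoutExpired:
        continue  # still running: poll again (the real loop yields progress here)

print(f"exit code: {return_code}\nstderr:\n{stderr}")
```

One minor caveat of this scheme: on POSIX, a child killed by signal N exits with `returncode == -N`, so `-1` doubles as the "still running" sentinel and as SIGHUP's code; the exit codes of interest here (0 for success, positive for failures) are unaffected.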