diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 304c8cdc..c0d87196 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -132,7 +132,7 @@ def _process_request(
                     if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
                         video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
                     elif os.path.isfile(video_url):  # local file
-                        video_stream = open(video_url, "rb")
+                        video_stream = video_url
                     else:  # web uri
                         video_stream = requests.get(video_url, stream=True).raw

@@ -143,7 +143,7 @@ def _process_request(
                     if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
                         audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
                     elif os.path.isfile(audio_url):  # local file
-                        audio_stream = open(audio_url, "rb")
+                        audio_stream = audio_url
                     else:  # web uri
                         audio_stream = requests.get(audio_url, stream=True).raw

diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py
index b078a966..f234a807 100644
--- a/src/llamafactory/extras/logging.py
+++ b/src/llamafactory/extras/logging.py
@@ -50,7 +50,7 @@ class LoggerHandler(logging.Handler):

     def _write_log(self, log_entry: str) -> None:
         with open(self.running_log, "a", encoding="utf-8") as f:
-            f.write(log_entry + "\n\n")
+            f.write(log_entry + "\n")

     def emit(self, record) -> None:
         if record.name == "httpx":
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 152961cc..7a86c23d 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -182,8 +182,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +207,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1


 def has_tokenized_data(path: "os.PathLike") -> bool:
diff --git a/src/llamafactory/webui/components/__init__.py b/src/llamafactory/webui/components/__init__.py
index eb3c9d4c..e2c64ea7 100644
--- a/src/llamafactory/webui/components/__init__.py
+++ b/src/llamafactory/webui/components/__init__.py
@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",
diff --git a/src/llamafactory/webui/components/footer.py b/src/llamafactory/webui/components/footer.py
new file mode 100644
index 00000000..f1b73abb
--- /dev/null
+++ b/src/llamafactory/webui/components/footer.py
@@ -0,0 +1,42 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras.misc import get_current_memory
+from ...extras.packages import is_gradio_available
+
+
+if is_gradio_available():
+    import gradio as gr
+
+
+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
+def get_device_memory() -> "gr.Slider":
+    free, total = get_current_memory()
+    used = round((total - free) / (1024**3), 2)
+    total = round(total / (1024**3), 2)
+    return gr.Slider(minimum=0, maximum=total, value=used, step=0.01)
+
+
+def create_footer() -> dict[str, "Component"]:
+    with gr.Row():
+        device_memory = gr.Slider(interactive=False)
+        timer = gr.Timer(value=5)
+
+    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
+    return dict(device_memory=device_memory)
diff --git a/src/llamafactory/webui/control.py b/src/llamafactory/webui/control.py
index 1103f27a..bbd90b66 100644
--- a/src/llamafactory/webui/control.py
+++ b/src/llamafactory/webui/control.py
@@ -112,7 +112,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log

     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
diff --git a/src/llamafactory/webui/interface.py b/src/llamafactory/webui/interface.py
index 691a88a3..b52cbf16 100644
--- a/src/llamafactory/webui/interface.py
+++ b/src/llamafactory/webui/interface.py
@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -63,6 +64,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))

+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 17ffb0a4..b028abd8 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -2849,6 +2849,28 @@
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存(GB)。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ(GB)。",
+        },
+    },
 }
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index e0becdcf..a781076b 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -16,14 +16,13 @@
 import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional

-from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@

         return ""

-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info

     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )

         if get("eval.predict"):
@@ -375,7 +375,7 @@
             env["FORCE_TORCHRUN"] = "1"

         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()

     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None

         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@
                     return_dict[swanlab_link] = running_info["swanlab_link"]

                 yield return_dict
+
             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue

-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
+        if return_code == 0:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
             else:
-                finish_info = ALERTS["err_failed"][lang]
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
         else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"

-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
         yield return_dict

     def save_args(self, data):