mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[webui] upgrade webui and fix api (#8460)
This commit is contained in:
		
							parent
							
								
									f276b9a963
								
							
						
					
					
						commit
						4407231a3b
					
				@ -132,7 +132,7 @@ def _process_request(
 | 
			
		||||
                    if re.match(r"^data:video\/(mp4|mkv|avi|mov);base64,(.+)$", video_url):  # base64 video
 | 
			
		||||
                        video_stream = io.BytesIO(base64.b64decode(video_url.split(",", maxsplit=1)[1]))
 | 
			
		||||
                    elif os.path.isfile(video_url):  # local file
 | 
			
		||||
                        video_stream = open(video_url, "rb")
 | 
			
		||||
                        video_stream = video_url
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        video_stream = requests.get(video_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
@ -143,7 +143,7 @@ def _process_request(
 | 
			
		||||
                    if re.match(r"^data:audio\/(mpeg|mp3|wav|ogg);base64,(.+)$", audio_url):  # base64 audio
 | 
			
		||||
                        audio_stream = io.BytesIO(base64.b64decode(audio_url.split(",", maxsplit=1)[1]))
 | 
			
		||||
                    elif os.path.isfile(audio_url):  # local file
 | 
			
		||||
                        audio_stream = open(audio_url, "rb")
 | 
			
		||||
                        audio_stream = audio_url
 | 
			
		||||
                    else:  # web uri
 | 
			
		||||
                        audio_stream = requests.get(audio_url, stream=True).raw
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -50,7 +50,7 @@ class LoggerHandler(logging.Handler):
 | 
			
		||||
 | 
			
		||||
    def _write_log(self, log_entry: str) -> None:
 | 
			
		||||
        with open(self.running_log, "a", encoding="utf-8") as f:
 | 
			
		||||
            f.write(log_entry + "\n\n")
 | 
			
		||||
            f.write(log_entry + "\n")
 | 
			
		||||
 | 
			
		||||
    def emit(self, record) -> None:
 | 
			
		||||
        if record.name == "httpx":
 | 
			
		||||
 | 
			
		||||
@ -182,8 +182,22 @@ def get_logits_processor() -> "LogitsProcessorList":
 | 
			
		||||
    return logits_processor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_current_memory() -> tuple[int, int]:
 | 
			
		||||
    r"""Get the available and total memory for the current device (in Bytes)."""
 | 
			
		||||
    if is_torch_xpu_available():
 | 
			
		||||
        return torch.xpu.mem_get_info()
 | 
			
		||||
    elif is_torch_npu_available():
 | 
			
		||||
        return torch.npu.mem_get_info()
 | 
			
		||||
    elif is_torch_mps_available():
 | 
			
		||||
        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
 | 
			
		||||
    elif is_torch_cuda_available():
 | 
			
		||||
        return torch.cuda.mem_get_info()
 | 
			
		||||
    else:
 | 
			
		||||
        return 0, -1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_peak_memory() -> tuple[int, int]:
 | 
			
		||||
    r"""Get the peak memory usage for the current device (in Bytes)."""
 | 
			
		||||
    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
 | 
			
		||||
    if is_torch_xpu_available():
 | 
			
		||||
        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
 | 
			
		||||
    elif is_torch_npu_available():
 | 
			
		||||
@ -193,7 +207,7 @@ def get_peak_memory() -> tuple[int, int]:
 | 
			
		||||
    elif is_torch_cuda_available():
 | 
			
		||||
        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
 | 
			
		||||
    else:
 | 
			
		||||
        return 0, 0
 | 
			
		||||
        return 0, -1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def has_tokenized_data(path: "os.PathLike") -> bool:
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@
 | 
			
		||||
from .chatbot import create_chat_box
 | 
			
		||||
from .eval import create_eval_tab
 | 
			
		||||
from .export import create_export_tab
 | 
			
		||||
from .footer import create_footer
 | 
			
		||||
from .infer import create_infer_tab
 | 
			
		||||
from .top import create_top
 | 
			
		||||
from .train import create_train_tab
 | 
			
		||||
@ -24,6 +25,7 @@ __all__ = [
 | 
			
		||||
    "create_chat_box",
 | 
			
		||||
    "create_eval_tab",
 | 
			
		||||
    "create_export_tab",
 | 
			
		||||
    "create_footer",
 | 
			
		||||
    "create_infer_tab",
 | 
			
		||||
    "create_top",
 | 
			
		||||
    "create_train_tab",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										42
									
								
								src/llamafactory/webui/components/footer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								src/llamafactory/webui/components/footer.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
			
		||||
# Copyright 2025 the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
from ...extras.misc import get_current_memory
 | 
			
		||||
from ...extras.packages import is_gradio_available
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_gradio_available():
 | 
			
		||||
    import gradio as gr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from gradio.components import Component
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_device_memory() -> "gr.Slider":
 | 
			
		||||
    free, total = get_current_memory()
 | 
			
		||||
    used = round((total - free) / (1024**3), 2)
 | 
			
		||||
    total = round(total / (1024**3), 2)
 | 
			
		||||
    return gr.Slider(minimum=0, maximum=total, value=used, step=0.01)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_footer() -> dict[str, "Component"]:
 | 
			
		||||
    with gr.Row():
 | 
			
		||||
        device_memory = gr.Slider(interactive=False)
 | 
			
		||||
        timer = gr.Timer(value=5)
 | 
			
		||||
 | 
			
		||||
    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
 | 
			
		||||
    return dict(device_memory=device_memory)
 | 
			
		||||
@ -112,7 +112,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
 | 
			
		||||
    running_log_path = os.path.join(output_path, RUNNING_LOG)
 | 
			
		||||
    if os.path.isfile(running_log_path):
 | 
			
		||||
        with open(running_log_path, encoding="utf-8") as f:
 | 
			
		||||
            running_log = f.read()[-20000:]  # avoid lengthy log
 | 
			
		||||
            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log
 | 
			
		||||
 | 
			
		||||
    trainer_log_path = os.path.join(output_path, TRAINER_LOG)
 | 
			
		||||
    if os.path.isfile(trainer_log_path):
 | 
			
		||||
 | 
			
		||||
@ -22,6 +22,7 @@ from .components import (
 | 
			
		||||
    create_chat_box,
 | 
			
		||||
    create_eval_tab,
 | 
			
		||||
    create_export_tab,
 | 
			
		||||
    create_footer,
 | 
			
		||||
    create_infer_tab,
 | 
			
		||||
    create_top,
 | 
			
		||||
    create_train_tab,
 | 
			
		||||
@ -63,6 +64,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
 | 
			
		||||
            with gr.Tab("Export"):
 | 
			
		||||
                engine.manager.add_elems("export", create_export_tab(engine))
 | 
			
		||||
 | 
			
		||||
        engine.manager.add_elems("footer", create_footer())
 | 
			
		||||
        demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
 | 
			
		||||
        lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
 | 
			
		||||
        lang.input(save_config, inputs=[lang], queue=False)
 | 
			
		||||
 | 
			
		||||
@ -2849,6 +2849,28 @@ LOCALES = {
 | 
			
		||||
            "value": "エクスポート",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "device_memory": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Device memory",
 | 
			
		||||
            "info": "Current memory usage of the device (GB).",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Память устройства",
 | 
			
		||||
            "info": "Текущая память на устройстве (GB).",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "设备显存",
 | 
			
		||||
            "info": "当前设备的显存(GB)。",
 | 
			
		||||
        },
 | 
			
		||||
        "ko": {
 | 
			
		||||
            "label": "디바이스 메모리",
 | 
			
		||||
            "info": "지금 사용 중인 기기 메모리 (GB).",
 | 
			
		||||
        },
 | 
			
		||||
        "ja": {
 | 
			
		||||
            "label": "デバイスメモリ",
 | 
			
		||||
            "info": "現在のデバイスのメモリ(GB)。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,14 +16,13 @@ import json
 | 
			
		||||
import os
 | 
			
		||||
from collections.abc import Generator
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from subprocess import Popen, TimeoutExpired
 | 
			
		||||
from subprocess import PIPE, Popen, TimeoutExpired
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Optional
 | 
			
		||||
 | 
			
		||||
from transformers.trainer import TRAINING_ARGS_NAME
 | 
			
		||||
from transformers.utils import is_torch_npu_available
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
 | 
			
		||||
from ..extras.misc import is_accelerator_available, torch_gc, use_ray
 | 
			
		||||
from ..extras.misc import is_accelerator_available, torch_gc
 | 
			
		||||
from ..extras.packages import is_gradio_available
 | 
			
		||||
from .common import (
 | 
			
		||||
    DEFAULT_CACHE_DIR,
 | 
			
		||||
@ -114,7 +113,7 @@ class Runner:
 | 
			
		||||
 | 
			
		||||
        return ""
 | 
			
		||||
 | 
			
		||||
    def _finalize(self, lang: str, finish_info: str) -> str:
 | 
			
		||||
    def _finalize(self, lang: str, finish_info: str) -> None:
 | 
			
		||||
        r"""Clean the cached memory and resets the runner."""
 | 
			
		||||
        finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
 | 
			
		||||
        gr.Info(finish_info)
 | 
			
		||||
@ -123,7 +122,6 @@ class Runner:
 | 
			
		||||
        self.running = False
 | 
			
		||||
        self.running_data = None
 | 
			
		||||
        torch_gc()
 | 
			
		||||
        return finish_info
 | 
			
		||||
 | 
			
		||||
    def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
 | 
			
		||||
        r"""Build and validate the training arguments."""
 | 
			
		||||
@ -314,11 +312,13 @@ class Runner:
 | 
			
		||||
            max_samples=int(get("eval.max_samples")),
 | 
			
		||||
            per_device_eval_batch_size=get("eval.batch_size"),
 | 
			
		||||
            predict_with_generate=True,
 | 
			
		||||
            report_to="none",
 | 
			
		||||
            max_new_tokens=get("eval.max_new_tokens"),
 | 
			
		||||
            top_p=get("eval.top_p"),
 | 
			
		||||
            temperature=get("eval.temperature"),
 | 
			
		||||
            output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
 | 
			
		||||
            trust_remote_code=True,
 | 
			
		||||
            ddp_timeout=180000000,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if get("eval.predict"):
 | 
			
		||||
@ -375,7 +375,7 @@ class Runner:
 | 
			
		||||
                env["FORCE_TORCHRUN"] = "1"
 | 
			
		||||
 | 
			
		||||
            # NOTE: DO NOT USE shell=True to avoid security risk
 | 
			
		||||
            self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
 | 
			
		||||
            self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
 | 
			
		||||
            yield from self.monitor()
 | 
			
		||||
 | 
			
		||||
    def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
 | 
			
		||||
@ -417,7 +417,8 @@ class Runner:
 | 
			
		||||
        swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None
 | 
			
		||||
 | 
			
		||||
        running_log = ""
 | 
			
		||||
        while self.trainer is not None:
 | 
			
		||||
        return_code = -1
 | 
			
		||||
        while return_code == -1:
 | 
			
		||||
            if self.aborted:
 | 
			
		||||
                yield {
 | 
			
		||||
                    output_box: ALERTS["info_aborting"][lang],
 | 
			
		||||
@ -436,27 +437,26 @@ class Runner:
 | 
			
		||||
                    return_dict[swanlab_link] = running_info["swanlab_link"]
 | 
			
		||||
 | 
			
		||||
                yield return_dict
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                self.trainer.wait(2)
 | 
			
		||||
                self.trainer = None
 | 
			
		||||
                stderr = self.trainer.communicate(timeout=2)[1]
 | 
			
		||||
                return_code = self.trainer.returncode
 | 
			
		||||
            except TimeoutExpired:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
        if self.do_train:
 | 
			
		||||
            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
 | 
			
		||||
                finish_info = ALERTS["info_finished"][lang]
 | 
			
		||||
        if return_code == 0:
 | 
			
		||||
            finish_info = ALERTS["info_finished"][lang]
 | 
			
		||||
            if self.do_train:
 | 
			
		||||
                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
 | 
			
		||||
            else:
 | 
			
		||||
                finish_info = ALERTS["err_failed"][lang]
 | 
			
		||||
                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
 | 
			
		||||
        else:
 | 
			
		||||
            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
 | 
			
		||||
                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
 | 
			
		||||
            else:
 | 
			
		||||
                finish_info = ALERTS["err_failed"][lang]
 | 
			
		||||
            print(stderr)
 | 
			
		||||
            finish_info = ALERTS["err_failed"][lang]
 | 
			
		||||
            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"
 | 
			
		||||
 | 
			
		||||
        return_dict = {
 | 
			
		||||
            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
 | 
			
		||||
            progress_bar: gr.Slider(visible=False),
 | 
			
		||||
        }
 | 
			
		||||
        self._finalize(lang, finish_info)
 | 
			
		||||
        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
 | 
			
		||||
        yield return_dict
 | 
			
		||||
 | 
			
		||||
    def save_args(self, data):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user