[webui] upgrade webui and fix api (#8460)
@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional

 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
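The import changes track the behavior changes below: PIPE comes in so the trainer subprocess's stderr can be captured, and use_ray goes away because success is now judged by the child's exit code instead of by probing output files (see the monitor hunks). A minimal sketch of what stderr=PIPE plus text=True gives you, with a stand-in child command:

from subprocess import PIPE, Popen

# Illustrative child process that writes to stderr and exits non-zero.
proc = Popen(
    ["python", "-c", "import sys; sys.stderr.write('boom'); sys.exit(1)"],
    stderr=PIPE,
    text=True,  # decode stderr into str instead of bytes
)
stderr = proc.communicate()[1]
print(proc.returncode, stderr)  # prints: 1 boom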
@@ -114,7 +113,7 @@ class Runner:
         return ""

-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info

     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
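With _finalize now returning None, the finish message no longer travels through its return value: the caller invokes it for its side effects (gr.Info, state reset, torch_gc) and assembles the output text itself, as the last hunk below shows. A self-contained sketch of the new contract, using stand-ins for the Gradio and torch helpers:

import gc

def _finalize(finish_info: str) -> None:
    # Side effects only: notify the user, then free memory.
    print(f"[info] {finish_info}")  # stands in for gr.Info(finish_info)
    gc.collect()                    # stands in for torch_gc()

running_log = "step 100 | loss 0.42"             # placeholder log text
finish_log = "Finished." + "\n\n" + running_log  # the caller owns the message
_finalize("Finished.")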
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
             report_to="none",
+            max_new_tokens=get("eval.max_new_tokens"),
+            top_p=get("eval.top_p"),
+            temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
-            ddp_timeout=180000000,
         )

         if get("eval.predict"):
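The evaluation run now forwards the generation settings chosen in the WebUI (max_new_tokens, top_p, temperature) instead of relying on library defaults, and the hard-coded ddp_timeout override is dropped. A hypothetical snapshot of the resulting arguments, with placeholder values where the real code calls get("eval.*"):

eval_args = dict(
    predict_with_generate=True,
    report_to="none",
    max_new_tokens=512,   # new: taken from the UI
    top_p=0.7,            # new: taken from the UI
    temperature=0.95,     # new: taken from the UI
    trust_remote_code=True,
    # ddp_timeout=180000000 is gone; the framework default applies
)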
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"

         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()

     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
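Piping stderr here is what makes the failure report in monitor possible; note that a piped stream must eventually be drained, which the communicate(timeout=...) polling in monitor does. The existing warning about shell=True is also worth illustrating, since an argv list keeps user-controlled values away from shell interpolation (the hostile value and the echo stand-in are invented for the example):

from subprocess import Popen

user_supplied = "config.yaml; rm -rf ~"  # hostile input, for illustration

# Unsafe: with shell=True the whole string is parsed by a shell, so the
# "; rm -rf ~" suffix would execute as a second command:
#   Popen(f"llamafactory-cli train {user_supplied}", shell=True)

# Safe: each argv element reaches the program verbatim, never a shell.
proc = Popen(["echo", user_supplied])  # echo stands in for the real CLI
proc.wait()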
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None

         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
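Looping on return_code == -1 rather than self.trainer is not None means the loop exits exactly when an exit status is known, and the Popen handle stays alive so its stderr can still be read. A runnable sketch of the communicate/TimeoutExpired polling pattern this hunk and the next one use (the sleeping child is a stand-in for the trainer):

from subprocess import PIPE, Popen, TimeoutExpired

proc = Popen(["python", "-c", "import time; time.sleep(5)"], stderr=PIPE, text=True)

return_code, stderr = -1, ""
while return_code == -1:
    try:
        # communicate() waits up to 2 seconds and drains the stderr pipe;
        # per the subprocess docs it may be called again after a timeout.
        stderr = proc.communicate(timeout=2)[1]
        return_code = proc.returncode
    except TimeoutExpired:
        continue  # child still running; poll again (the UI refreshes here)

print(return_code, stderr)  # prints 0 and an empty stderr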
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]

                 yield return_dict

             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue

-            if self.do_train:
-                if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                    finish_info = ALERTS["info_finished"][lang]
-                else:
-                    finish_info = ALERTS["err_failed"][lang]
-            else:
-                if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                    finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-                else:
-                    finish_info = ALERTS["err_failed"][lang]
+        if return_code == 0:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"

-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
         yield return_dict

     def save_args(self, data):
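On failure, the captured stderr is surfaced in the output box wrapped in triple backticks, so a markdown-rendering component displays it as a code block next to the exit code. A minimal sketch of the resulting string, with placeholders for the alert text and the traceback:

return_code = 1
stderr = "Traceback (most recent call last):\n  ..."
# "Failed." stands in for ALERTS["err_failed"][lang]
finish_log = "Failed." + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"
print(finish_log)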