mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix eval in webui
This commit is contained in:
@@ -5,7 +5,7 @@ import signal
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
@@ -38,8 +38,20 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
""" Progress """
|
||||
self.start_time = 0
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
self.thread_pool: Optional["ThreadPoolExecutor"] = None
|
||||
""" Status """
|
||||
self.aborted = False
|
||||
self.do_train = False
|
||||
""" Web UI """
|
||||
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
@@ -66,6 +78,19 @@ class LogCallback(TrainerCallback):
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def _create_thread_pool(self, output_dir: str) -> None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def _close_thread_pool(self) -> None:
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
@@ -73,8 +98,7 @@ class LogCallback(TrainerCallback):
|
||||
if args.should_save:
|
||||
self.do_train = True
|
||||
self._reset(max_steps=state.max_steps)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
if (
|
||||
args.should_save
|
||||
@@ -84,6 +108,12 @@ class LogCallback(TrainerCallback):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
@@ -103,31 +133,19 @@ class LogCallback(TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
self._close_thread_pool()
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if args.should_save and has_length(eval_dataloader) and not self.do_train:
|
||||
if self.max_steps == 0:
|
||||
self.max_steps = len(eval_dataloader)
|
||||
self._close_thread_pool()
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
|
||||
def _write_log(self, output_dir: str, logs: Dict[str, Any]):
|
||||
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
@@ -158,3 +176,26 @@ class LogCallback(TrainerCallback):
|
||||
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if args.should_save and has_length(eval_dataloader) and not self.do_train:
|
||||
if self.max_steps == 0:
|
||||
self._reset(max_steps=len(eval_dataloader))
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
self._timing(cur_steps=self.cur_steps + 1)
|
||||
if self.cur_steps % 5 == 0 and self.thread_pool is not None:
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time,
|
||||
)
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
Reference in New Issue
Block a user