update webui and add CLIs

This commit is contained in:
hiyouga
2024-05-03 02:58:23 +08:00
parent 39e964a97a
commit 245fe47ece
65 changed files with 363 additions and 372 deletions

View File

@@ -1,14 +1,18 @@
 import json
+import logging
 import os
+import signal
 import time
+from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
 
+import transformers
 from transformers import TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
 
-from .constants import LOG_FILE_NAME
-from .logging import get_logger
+from .constants import TRAINER_LOG
+from .logging import LoggerHandler, get_logger
 from .misc import fix_valuehead_checkpoint
@@ -33,20 +37,32 @@ class FixValueHeadModelCallback(TrainerCallback):
 class LogCallback(TrainerCallback):
-    def __init__(self, runner=None):
-        self.runner = runner
-        self.in_training = False
+    def __init__(self, output_dir: str) -> None:
+        self.aborted = False
+        self.do_train = False
+        self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
+        if self.webui_mode:
+            signal.signal(signal.SIGABRT, self._set_abort)
+            self.logger_handler = LoggerHandler(output_dir)
+            logging.root.addHandler(self.logger_handler)
+            transformers.logging.add_handler(self.logger_handler)
+
+    def _set_abort(self, signum, frame) -> None:
+        self.aborted = True
+
+    def _reset(self, max_steps: int = 0) -> None:
         self.start_time = time.time()
         self.cur_steps = 0
-        self.max_steps = 0
+        self.max_steps = max_steps
         self.elapsed_time = ""
         self.remaining_time = ""
 
-    def timing(self):
+    def _timing(self, cur_steps: int) -> None:
         cur_time = time.time()
         elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
-        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
+        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
+        self.cur_steps = cur_steps
         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
         self.remaining_time = str(timedelta(seconds=int(remaining_time)))
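Note on the reworked `_timing`: it now receives the current step, stores it, and derives an ETA from the average time per finished step. A self-contained sketch of the same arithmetic, with made-up numbers (none of these values come from the diff):

# Same arithmetic as _timing above, with hypothetical numbers.
from datetime import timedelta

start_time = 0.0                 # captured by _reset() when training begins
cur_time = 90.0                  # pretend we are 90 seconds in
cur_steps, max_steps = 30, 120

elapsed_time = cur_time - start_time                                   # 90.0
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0  # 3.0
remaining_time = (max_steps - cur_steps) * avg_time_per_step           # 270.0

print(str(timedelta(seconds=int(elapsed_time))))    # 0:01:30
print(str(timedelta(seconds=int(remaining_time))))  # 0:04:30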
@@ -54,36 +70,27 @@ class LogCallback(TrainerCallback):
         r"""
         Event called at the beginning of training.
         """
-        if state.is_local_process_zero:
-            self.in_training = True
-            self.start_time = time.time()
-            self.max_steps = state.max_steps
+        if args.should_log:
+            self.do_train = True
+            self._reset(max_steps=state.max_steps)
 
-        if args.save_on_each_node:
-            if not state.is_local_process_zero:
-                return
-        else:
-            if not state.is_world_process_zero:
-                return
-
         if args.should_save:
             os.makedirs(args.output_dir, exist_ok=True)
+            self.thread_pool = ThreadPoolExecutor(max_workers=1)
 
-        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
-            logger.warning("Previous log file in this folder will be deleted.")
-            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
-
-    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
-        if state.is_local_process_zero:
-            self.in_training = False
-            self.cur_steps = 0
-            self.max_steps = 0
+        if (
+            args.should_save
+            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
+            and args.overwrite_output_dir
+        ):
+            logger.warning("Previous trainer log in this folder will be deleted.")
+            os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """
-        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
+        if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True
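Note on the abort path: the callback no longer polls a Gradio runner object. In webui mode (`LLAMABOARD_ENABLED=1`) it installs a SIGABRT handler, so a supervising process can stop training by signaling the trainer process. The sending side is not part of this diff; a hedged sketch of what it could look like (`train_pid` is hypothetical):

# Sending side of the abort path (not shown in this commit): a supervisor
# such as the Web UI process could signal the training subprocess.
import os
import signal

def abort_training(train_pid: int) -> None:
    os.kill(train_pid, signal.SIGABRT)  # received by LogCallback._set_abort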
@@ -91,42 +98,41 @@ class LogCallback(TrainerCallback):
         r"""
         Event called at the end of a training step.
         """
-        if state.is_local_process_zero:
-            self.cur_steps = state.global_step
-            self.timing()
-            if self.runner is not None and self.runner.aborted:
-                control.should_epoch_stop = True
-                control.should_training_stop = True
+        if args.should_log:
+            self._timing(cur_steps=state.global_step)
 
-    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if self.aborted:
+            control.should_epoch_stop = True
+            control.should_training_stop = True
+
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
-        Event called after an evaluation phase.
+        Event called at the end of training.
         """
-        if state.is_local_process_zero and not self.in_training:
-            self.cur_steps = 0
-            self.max_steps = 0
+        self.thread_pool.shutdown(wait=True)
+        self.thread_pool = None
 
-    def on_predict(
-        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
+    def on_prediction_step(
+        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
         r"""
-        Event called after a successful prediction.
+        Event called after a prediction step.
         """
-        if state.is_local_process_zero and not self.in_training:
-            self.cur_steps = 0
-            self.max_steps = 0
+        eval_dataloader = kwargs.pop("eval_dataloader", None)
+        if args.should_log and has_length(eval_dataloader) and not self.do_train:
+            if self.max_steps == 0:
+                self.max_steps = len(eval_dataloader)
+
+            self._timing(cur_steps=self.cur_steps + 1)
+
+    def _write_log(self, output_dir: str, logs: Dict[str, Any]):
+        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
+            f.write(json.dumps(logs) + "\n")
 
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
-        Event called after logging the last logs.
+        Event called after logging the last logs, `args.should_log` has been applied.
        """
-        if args.save_on_each_node:
-            if not state.is_local_process_zero:
-                return
-        else:
-            if not state.is_world_process_zero:
-                return
-
         logs = dict(
             current_steps=self.cur_steps,
             total_steps=self.max_steps,
@@ -141,26 +147,13 @@ class LogCallback(TrainerCallback):
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
         )
-        if self.runner is not None:
+        logs = {k: v for k, v in logs.items() if v is not None}
+        if self.webui_mode and "loss" in logs and "learning_rate" in logs and "epoch" in logs:
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
-                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
+                    logs["loss"], logs["learning_rate"], logs["epoch"]
                 )
             )
 
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(logs) + "\n")
-
-    def on_prediction_step(
-        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
-    ):
-        r"""
-        Event called after a prediction step.
-        """
-        eval_dataloader = kwargs.pop("eval_dataloader", None)
-        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
-            if self.max_steps == 0:
-                self.max_steps = len(eval_dataloader)
-            self.cur_steps += 1
-            self.timing()
+        if args.should_save and self.thread_pool is not None:
+            self.thread_pool.submit(self._write_log, args.output_dir, logs)
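Note: with this change every `on_log` event appends one JSON object per line to the `TRAINER_LOG` file, serialized through the single-worker `ThreadPoolExecutor`, so records stay ordered without blocking the training loop. A minimal reader sketch (the path is hypothetical; the keys follow the `logs = dict(...)` block above):

# Reading the JSON-lines log written by _write_log above.
import json
from typing import Any, Dict, List

def read_trainer_log(path: str) -> List[Dict[str, Any]]:
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    return records

for rec in read_trainer_log("saves/test/trainer_log.jsonl"):
    print(rec["current_steps"], rec["total_steps"], rec.get("loss"))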

View File

@@ -24,8 +24,6 @@ IGNORE_INDEX = -100
 LAYERNORM_NAMES = {"norm", "ln"}
-LOG_FILE_NAME = "trainer_log.jsonl"
 METHODS = ["full", "freeze", "lora"]
 MLLM_LIST = ["LLaVA1.5"]
@@ -34,10 +32,16 @@ MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral
 PEFT_METHODS = ["lora"]
+RUNNING_LOG = "running_log.txt"
 SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
 SUPPORTED_MODELS = OrderedDict()
+TRAINER_CONFIG = "trainer_config.yaml"
+TRAINER_LOG = "trainer_log.jsonl"
 TRAINING_STAGES = {
     "Supervised Fine-Tuning": "sft",
     "Reward Modeling": "rm",

View File

@@ -1,5 +1,9 @@
 import logging
+import os
 import sys
+from concurrent.futures import ThreadPoolExecutor
+
+from .constants import RUNNING_LOG
 
 
 class LoggerHandler(logging.Handler):
@@ -7,19 +11,35 @@ class LoggerHandler(logging.Handler):
     Logger handler used in Web UI.
     """
 
-    def __init__(self):
+    def __init__(self, output_dir: str) -> None:
         super().__init__()
-        self.log = ""
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        )
+        self.setLevel(logging.INFO)
+        self.setFormatter(formatter)
 
-    def reset(self):
-        self.log = ""
+        os.makedirs(output_dir, exist_ok=True)
+        self.running_log = os.path.join(output_dir, RUNNING_LOG)
+        if os.path.exists(self.running_log):
+            os.remove(self.running_log)
 
-    def emit(self, record):
+        self.thread_pool = ThreadPoolExecutor(max_workers=1)
+
+    def _write_log(self, log_entry: str) -> None:
+        with open(self.running_log, "a", encoding="utf-8") as f:
+            f.write(log_entry + "\n\n")
+
+    def emit(self, record) -> None:
         if record.name == "httpx":
             return
+
         log_entry = self.format(record)
-        self.log += log_entry
-        self.log += "\n\n"
+        self.thread_pool.submit(self._write_log, log_entry)
+
+    def close(self) -> None:
+        self.thread_pool.shutdown(wait=True)
+        return super().close()
 
 
 def get_logger(name: str) -> logging.Logger:
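Note: `LoggerHandler` now streams each formatted record to the `RUNNING_LOG` file from a single worker thread instead of accumulating a string in memory, and skips records from `httpx`. A minimal usage sketch mirroring what `LogCallback.__init__` does in the callbacks file above (the output directory is hypothetical):

# Hypothetical usage, assuming LoggerHandler is importable as above.
import logging

handler = LoggerHandler(output_dir="demo_output")
logging.root.setLevel(logging.INFO)
logging.root.addHandler(handler)

logging.getLogger("demo").info("hello")  # appended to demo_output/running_log.txt
handler.close()  # drains the single-worker pool before closing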

View File

@@ -1,7 +1,7 @@
 import json
 import math
 import os
-from typing import List
+from typing import Any, Dict, List
 
 from transformers.trainer import TRAINER_STATE_NAME
@@ -10,6 +10,7 @@ from .packages import is_matplotlib_available
 if is_matplotlib_available():
+    import matplotlib.figure
     import matplotlib.pyplot as plt
@@ -21,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
     EMA implementation according to TensorBoard.
     """
     last = scalars[0]
-    smoothed = list()
+    smoothed = []
     weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
     for next_val in scalars:
         smoothed_val = last * weight + (1 - weight) * next_val
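Note: the EMA weight is a function of the series length, so longer runs get heavier smoothing. Worked numbers for two example lengths:

import math

for n in (10, 100):  # number of logged loss points (example values)
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * n)) - 0.5)
    print(n, round(weight, 3))  # 10 -> 0.22, 100 -> 0.888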
@@ -30,7 +31,27 @@ def smooth(scalars: List[float]) -> List[float]:
     return smoothed
 
 
+def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
+    plt.close("all")
+    plt.switch_backend("agg")
+    fig = plt.figure()
+    ax = fig.add_subplot(111)
+    steps, losses = [], []
+    for log in trainer_log:
+        if log.get("loss", None):
+            steps.append(log["current_steps"])
+            losses.append(log["loss"])
+
+    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
+    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
+    ax.legend()
+    ax.set_xlabel("step")
+    ax.set_ylabel("loss")
+    return fig
+
+
 def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
     plt.switch_backend("agg")
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
         data = json.load(f)
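Note: `gen_loss_plot` renders the Web UI loss curve off-screen (`agg` backend) from parsed trainer-log records. A hedged usage sketch (paths are hypothetical; the function and record format come from the diff above):

import json

with open("saves/test/trainer_log.jsonl", "r", encoding="utf-8") as f:
    trainer_log = [json.loads(line) for line in f if line.strip()]

fig = gen_loss_plot(trainer_log)
fig.savefig("saves/test/loss.png", dpi=100)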