# Mirror of https://github.com/hiyouga/LLaMA-Factory.git
# Synced 2025-09-11 23:52:50 +08:00 (214 lines, 8.1 KiB, Python)
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import timedelta
|
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
|
|
import transformers
|
|
from transformers import TrainerCallback
|
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
|
|
|
from .constants import TRAINER_LOG
|
|
from .logging import LoggerHandler, get_logger
|
|
from .misc import fix_valuehead_checkpoint
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from transformers import TrainerControl, TrainerState, TrainingArguments
|
|
|
|
|
|
# Module-level logger obtained from the project's ``get_logger`` helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class FixValueHeadModelCallback(TrainerCallback):
    r"""
    Calls ``fix_valuehead_checkpoint`` on every checkpoint directory the trainer saves.
    """

    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.

        Only runs on the process for which ``args.should_save`` is true. The checkpoint directory is
        ``<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<state.global_step>``, and the model is taken from
        the ``model`` keyword argument supplied with the event.
        """
        if not args.should_save:
            return

        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"),
            output_dir=checkpoint_dir,
            safe_serialization=args.save_safetensors,
        )
|
|
|
|
class LogCallback(TrainerCallback):
    r"""
    Tracks training/prediction progress and appends it as JSON lines to ``TRAINER_LOG``.

    Records are written to ``<args.output_dir>/TRAINER_LOG`` by a single-worker thread pool, so the
    file writes run off the thread that fires the callback events. When the ``LLAMABOARD_ENABLED``
    environment variable parses to a non-zero integer (Web UI mode), a SIGABRT handler is installed
    and, once triggered, the callback asks the trainer to stop.
    """

    def __init__(self, output_dir: str) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.

        Args:
            output_dir: directory handed to ``LoggerHandler``; only used in Web UI mode.
        """
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        # NOTE(review): ``int()`` raises ValueError for non-integer values such as "true".
        self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0")))
        if self.webui_mode:
            # SIGABRT (presumably sent by the Web UI to cancel a run) only sets a flag; the step and
            # prediction hooks poll that flag and stop the run from there.
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(output_dir)
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""Signal handler: records that an abort was requested."""
        self.aborted = True

    def _stop_if_aborted(self, control: "TrainerControl") -> None:
        r"""Asks the trainer to end the current epoch and the whole training once an abort was requested."""
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""Restarts the progress clock and counters for a run of ``max_steps`` steps."""
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""
        Updates ``cur_steps`` and the human-readable elapsed / estimated remaining time.

        The remaining time is a linear extrapolation of the average time per step observed so far.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""Appends ``logs`` as one JSON line to ``TRAINER_LOG`` inside ``output_dir``."""
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""
        Ensures ``output_dir`` exists and starts a fresh single-worker pool for log writes.

        Any pool left over from an earlier run is shut down first so its worker thread is not leaked.
        """
        os.makedirs(output_dir, exist_ok=True)
        self._close_thread_pool()
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""Waits for pending log writes, then releases the pool (no-op when there is none)."""
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def _finish_prediction(self) -> None:
        r"""
        Flushes and closes the log writer after a standalone evaluate/predict run.

        The progress counters are cleared too, so that the next evaluate/predict run re-initializes
        them from its own dataloader length in ``on_prediction_step`` instead of continuing the
        previous run's count with no writer.
        """
        self._close_thread_pool()
        self._reset()

    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.

        Deletes an existing trainer log in ``args.output_dir`` when ``overwrite_output_dir`` is set.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.

        On the saving process, marks the run as training, resets progress to ``state.max_steps``
        and starts the log writer.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.
        """
        self._stop_if_aborted(control)

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        self._stop_if_aborted(control)

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.

        This also fires for periodic evaluations *during* training. In that case the log writer must
        stay alive until ``on_train_end``, otherwise every later training log would be dropped. Only a
        standalone evaluation run closes it here.
        """
        if not self.do_train:
            self._finish_prediction()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.

        As with ``on_evaluate``, the log writer is only closed when no training run owns it.
        """
        if not self.do_train:
            self._finish_prediction()

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.

        Builds a progress record from the newest ``state.log_history`` entry, drops missing fields,
        echoes loss/lr/epoch through the logger in Web UI mode, and queues the record for writing.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        last_log = state.log_history[-1]
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss", None),
            eval_loss=last_log.get("eval_loss", None),
            predict_loss=last_log.get("predict_loss", None),
            reward=last_log.get("reward", None),
            accuracy=last_log.get("rewards/accuracies", None),
            learning_rate=last_log.get("learning_rate", None),
            epoch=last_log.get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"]
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.

        Only tracks standalone evaluate/predict runs; steps taken while training are ignored. An
        abort request exits the process. Every 5th step, a progress record is queued for writing.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                # First step of a new evaluate/predict run: size progress from the dataloader.
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)