mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
import os
|
|
import json
|
|
import time
|
|
from datetime import timedelta
|
|
|
|
from transformers import (
|
|
TrainerCallback,
|
|
TrainerControl,
|
|
TrainerState,
|
|
TrainingArguments
|
|
)
|
|
|
|
|
|
class LogCallback(TrainerCallback):
    r"""
    Trainer callback that honours abort requests from an optional runner and
    appends a progress record to ``trainer_log.jsonl`` in ``args.output_dir``
    every time a log entry containing a ``loss`` value is emitted.
    """

    def __init__(self, runner=None):
        r"""
        Args:
            runner: optional object exposing a boolean ``aborted`` attribute.
                When it is truthy, training is asked to stop at the next
                step / substep boundary. ``None`` disables abort handling.
        """
        self.runner = runner
        # Wall-clock reference used for the elapsed / remaining time estimates.
        self.start_time = time.time()
        # Most recent progress record written by ``on_log``.
        self.tracker = {}

    def _stop_if_aborted(self, control: TrainerControl) -> None:
        r"""
        Flags the control object so that training stops if the runner was aborted.
        """
        if self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the beginning of a training step. If using gradient accumulation, one training step
        might take several inputs.
        """
        self._stop_if_aborted(control)

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        self._stop_if_aborted(control)

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.

        Builds a progress record from the latest log entry, stores it in ``self.tracker``
        and appends it as one JSON line to ``<args.output_dir>/trainer_log.jsonl``.
        Entries without a ``loss`` key are ignored.
        """
        # Nothing has been logged yet: there is no entry to report on.
        if not state.log_history:
            return

        last_log = state.log_history[-1]
        if "loss" not in last_log:
            return

        elapsed_time = time.time() - self.start_time

        # A log entry without "step" would otherwise yield None and break the arithmetic below.
        cur_steps = last_log.get("step")
        if cur_steps is None:
            cur_steps = state.global_step

        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        # Clamp at zero so an unknown (0) or exceeded ``max_steps`` never produces a negative
        # duration, which ``timedelta`` would render as "-1 day, ...".
        remaining_steps = max(state.max_steps - cur_steps, 0)
        remaining_time = remaining_steps * avg_time_per_step

        self.tracker = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": last_log.get("loss", None),
            "reward": last_log.get("reward", None),
            "learning_rate": last_log.get("learning_rate", None),
            "epoch": last_log.get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time))),
        }

        # NOTE(review): every process that runs this callback appends to the same file;
        # confirm whether multi-process runs should restrict writing to the main process.
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(self.tracker) + "\n")
|