[v1] add callbacks (#10255)

commit c340aa2a33 (parent 1e536733c6)
Author: jiaqiw09
Date: 2026-03-26 19:59:57 +08:00
Committed by: GitHub
5 changed files with 293 additions and 2 deletions

@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
    CallbackHandler,
    LoggingCallback,
    TrainerCallback,
    TrainerState,
)
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
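
The callbacks module itself is one of the five changed files, which this view does not show. A minimal sketch of its interface, inferred only from how it is used in the hunks below (the hook names and constructor arguments match the calls in this diff; the bodies are assumptions, not the committed code):

    # Hypothetical sketch of utils/callbacks.py, inferred from usage in this diff.
    class TrainerCallback:
        """Base class: every hook is a no-op; subclasses override what they need."""

        def on_train_begin(self, args, state): ...
        def on_epoch_begin(self, args, state): ...
        def on_step_begin(self, args, state): ...
        def on_step_end(self, args, state): ...
        def on_log(self, args, state, logs): ...
        def on_epoch_end(self, args, state): ...
        def on_train_end(self, args, state): ...
        def on_save(self, args, state): ...

    class LoggingCallback(TrainerCallback):
        """Default callback: emits whatever the trainer hands to on_log."""

        def on_log(self, args, state, logs):
            print(logs)

    class CallbackHandler(TrainerCallback):
        """Fans each event out to every registered callback, in registration order."""

        def __init__(self, callbacks, trainer=None):
            self.callbacks = list(callbacks)
            self.trainer = trainer

        def add_callback(self, callback):
            self.callbacks.append(callback)

        def call_event(self, event, *args):
            for cb in self.callbacks:
                getattr(cb, event)(*args)

        def on_train_begin(self, args, state): self.call_event("on_train_begin", args, state)
        def on_epoch_begin(self, args, state): self.call_event("on_epoch_begin", args, state)
        def on_step_begin(self, args, state): self.call_event("on_step_begin", args, state)
        def on_step_end(self, args, state): self.call_event("on_step_end", args, state)
        def on_log(self, args, state, logs): self.call_event("on_log", args, state, logs)
        def on_epoch_end(self, args, state): self.call_event("on_epoch_end", args, state)
        def on_train_end(self, args, state): self.call_event("on_train_end", args, state)
        def on_save(self, args, state): self.call_event("on_save", args, state)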
@@ -52,6 +58,7 @@ class BaseTrainer:
        model: HFModel,
        renderer: Renderer,
        train_dataset: TorchDataset,
        callbacks: list[TrainerCallback] | None = None,
    ) -> None:
        self.args = args
        self.model = model
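
With the widened constructor, callers can now inject callbacks when building the trainer. A hypothetical usage (PrintEpochCallback is illustrative, not part of this commit):

    # Assumes args, model, renderer, and train_dataset are already built.
    class PrintEpochCallback(TrainerCallback):
        def on_epoch_begin(self, args, state):
            print(f"starting epoch {state.epoch}")

    trainer = BaseTrainer(
        args=args,
        model=model,
        renderer=renderer,
        train_dataset=train_dataset,
        callbacks=[PrintEpochCallback()],
    )

Assuming add_callback appends, user callbacks run after the built-in LoggingCallback registered in __init__, so the default logging behavior is preserved.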
@@ -99,6 +106,14 @@ class BaseTrainer:
        self._init_optimizer()
        self._init_lr_scheduler()

        # Callbacks
        self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
        for cb in callbacks or []:
            self.callback_handler.add_callback(cb)

        # Training state: tracks progress across the full run.
        self.state = TrainerState(num_training_steps=self.num_training_steps)

    def _create_batch_generator(self) -> None:
        self.train_batch_generator = BatchGenerator(
            dataset=self.train_dataset,
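
TrainerState is created once in __init__ and then mutated step by step in fit(). Its definition lives in an unshown file; judging only from the attributes this diff assigns, a plausible minimal shape:

    from dataclasses import dataclass

    # Assumed shape of TrainerState, based only on the fields this diff touches.
    @dataclass
    class TrainerState:
        num_training_steps: int
        epoch: int = 0
        global_step: int = 0
        loss: float = 0.0
        grad_norm: float = 0.0
        learning_rate: float = 0.0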
@@ -174,10 +189,18 @@ class BaseTrainer:
    def fit(self) -> None:
        """Train the model."""
        self.model.train()
        self.callback_handler.on_train_begin(self.args, self.state)

        for epoch in range(self.args.num_train_epochs):
            self.state.epoch = epoch
            self.train_batch_generator.set_epoch(epoch)
            self.callback_handler.on_epoch_begin(self.args, self.state)

            for micro_batches in self.train_batch_generator:
                self.global_step += 1
                self.state.global_step = self.global_step
                self.callback_handler.on_step_begin(self.args, self.state)

                step_loss = 0
                step_valid_tokens = compute_valid_tokens(micro_batches)
                step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
@@ -213,14 +236,41 @@ class BaseTrainer:
                step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                DistributedInterface().sync()
                if DistributedInterface().get_rank() == 0:
                    print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")

                # Update state with step metrics
                current_lr = (
                    self.lr_scheduler.get_last_lr()[0]
                    if hasattr(self.lr_scheduler, "get_last_lr")
                    else self.args.learning_rate
                )
                self.state.loss = step_loss
                self.state.grad_norm = grad_norm
                self.state.learning_rate = current_lr
                self.callback_handler.on_step_end(self.args, self.state)

                # Logging: trainer decides when to log
                if self.global_step % self.args.logging_steps == 0:
                    logs = {
                        "epoch": epoch,
                        "step": self.global_step,
                        "loss": step_loss,
                        "grad_norm": grad_norm,
                        "learning_rate": current_lr,
                    }
                    self.callback_handler.on_log(self.args, self.state, logs)

                # Check if max_steps is reached
                if self.global_step >= self.num_training_steps:
                    logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
                    self.callback_handler.on_epoch_end(self.args, self.state)
                    self.callback_handler.on_train_end(self.args, self.state)
                    return

            self.callback_handler.on_epoch_end(self.args, self.state)

        self.callback_handler.on_train_end(self.args, self.state)

    def save_model(self) -> None:
        """Save the model."""
        if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
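
A user-defined callback overrides only the hooks it needs. A hypothetical consumer of the hook sequence driven by fit() above (JsonlLogCallback is illustrative, not part of this commit):

    import json

    class JsonlLogCallback(TrainerCallback):
        """Appends each on_log payload to a JSON Lines file."""

        def __init__(self, path="train_log.jsonl"):
            self.path = path

        def on_log(self, args, state, logs):
            with open(self.path, "a") as f:
                f.write(json.dumps(logs) + "\n")

        def on_train_end(self, args, state):
            print(f"finished at step {state.global_step}; logs in {self.path}")

Note that the on_log call in fit() is not guarded by the rank check that protects the trainer's own print, so in a multi-process run every rank receives on_log; a real callback would likely add its own rank guard before writing.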
@@ -234,3 +284,5 @@ class BaseTrainer:
        model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
        self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
        logger.info_rank0(f"Model saved to {self.args.output_dir}")

        self.callback_handler.on_save(self.args, self.state)
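
Because on_save fires after the checkpoint files are written, a callback can act on the finished artifacts. A hypothetical example (UploadReminderCallback is illustrative, not part of this commit):

    class UploadReminderCallback(TrainerCallback):
        def on_save(self, args, state):
            # The output directory is fully populated at this point.
            print(f"checkpoint ready in {args.output_dir} at step {state.global_step}")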