mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-28 02:39:03 +08:00
[v1] add callbacks (#10255)
@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
 from ..accelerator.interface import Dim, DistributedInterface
 from ..config import TrainingArguments
 from ..utils import logging
+from ..utils.callbacks import (
+    CallbackHandler,
+    LoggingCallback,
+    TrainerCallback,
+    TrainerState,
+)
 from ..utils.helper import compute_valid_tokens
 from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
 from .utils.batching import BatchGenerator
@@ -52,6 +58,7 @@ class BaseTrainer:
         model: HFModel,
         renderer: Renderer,
         train_dataset: TorchDataset,
+        callbacks: list[TrainerCallback] | None = None,
     ) -> None:
         self.args = args
         self.model = model
@@ -99,6 +106,14 @@ class BaseTrainer:
         self._init_optimizer()
         self._init_lr_scheduler()
 
+        # Callbacks
+        self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
+        for cb in callbacks or []:
+            self.callback_handler.add_callback(cb)
+
+        # TrainerState tracks progress across the full run.
+        self.state = TrainerState(num_training_steps=self.num_training_steps)
+
     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
             dataset=self.train_dataset,
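For orientation, here is a minimal, self-contained sketch of the callback-dispatch pattern this hunk wires in. utils/callbacks.py itself is not part of this diff, so the class bodies below are illustrative assumptions: only the names (TrainerState, TrainerCallback, CallbackHandler, and a LoggingCallback presumably subclassing TrainerCallback) and the hook signatures visible in the diff come from the commit.

# Illustrative sketch only -- not the actual utils/callbacks.py.
from dataclasses import dataclass

@dataclass
class TrainerState:
    """Progress snapshot handed to every hook (fields inferred from this diff)."""
    num_training_steps: int = 0
    global_step: int = 0
    epoch: int = 0
    loss: float = 0.0
    grad_norm: float = 0.0
    learning_rate: float = 0.0

class TrainerCallback:
    """No-op base class: subclasses override only the hooks they need."""
    def on_train_begin(self, args, state): pass
    def on_epoch_begin(self, args, state): pass
    def on_step_begin(self, args, state): pass
    def on_step_end(self, args, state): pass
    def on_log(self, args, state, logs): pass
    def on_epoch_end(self, args, state): pass
    def on_train_end(self, args, state): pass
    def on_save(self, args, state): pass

class CallbackHandler:
    """Fans each hook out to every registered callback, in registration order."""
    def __init__(self, callbacks, trainer=None):
        self.callbacks = list(callbacks)
        self.trainer = trainer

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def __getattr__(self, name):
        # Dispatch any on_* hook to all callbacks; other attributes fail normally.
        # (The real handler may instead define each hook explicitly.)
        if name.startswith("on_"):
            def dispatch(*args, **kwargs):
                for cb in self.callbacks:
                    getattr(cb, name)(*args, **kwargs)
            return dispatch
        raise AttributeError(name)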
@@ -174,10 +189,18 @@ class BaseTrainer:
     def fit(self) -> None:
        """Train the model."""
         self.model.train()
+        self.callback_handler.on_train_begin(self.args, self.state)
         for epoch in range(self.args.num_train_epochs):
+            self.state.epoch = epoch
             self.train_batch_generator.set_epoch(epoch)
+            self.callback_handler.on_epoch_begin(self.args, self.state)
+
             for micro_batches in self.train_batch_generator:
                 self.global_step += 1
+
+                self.state.global_step = self.global_step
+                self.callback_handler.on_step_begin(self.args, self.state)
+
                 step_loss = 0
                 step_valid_tokens = compute_valid_tokens(micro_batches)
                 step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
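With hooks placed at these boundaries, a callback can bracket each optimizer step. A hedged example building on the sketch above (StepTimerCallback is a hypothetical name, not part of the commit):

import time

class StepTimerCallback(TrainerCallback):
    """Measures wall-clock time per training step via the step hooks."""
    def on_step_begin(self, args, state):
        self._t0 = time.perf_counter()

    def on_step_end(self, args, state):
        # on_step_end always follows on_step_begin in fit(), so _t0 is set.
        elapsed = time.perf_counter() - self._t0
        print(f"step {state.global_step}: {elapsed:.3f}s")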
@@ -213,14 +236,41 @@ class BaseTrainer:
 
                 step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                 DistributedInterface().sync()
                 if DistributedInterface().get_rank() == 0:
                     print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
 
+                # Update state with step metrics
+                current_lr = (
+                    self.lr_scheduler.get_last_lr()[0]
+                    if hasattr(self.lr_scheduler, "get_last_lr")
+                    else self.args.learning_rate
+                )
+                self.state.loss = step_loss
+                self.state.grad_norm = grad_norm
+                self.state.learning_rate = current_lr
+
+                self.callback_handler.on_step_end(self.args, self.state)
+
+                # Logging: trainer decides when to log
+                if self.global_step % self.args.logging_steps == 0:
+                    logs = {
+                        "epoch": epoch,
+                        "step": self.global_step,
+                        "loss": step_loss,
+                        "grad_norm": grad_norm,
+                        "learning_rate": current_lr,
+                    }
+                    self.callback_handler.on_log(self.args, self.state, logs)
+
                 # Check if max_steps is reached
                 if self.global_step >= self.num_training_steps:
                     logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
+                    self.callback_handler.on_epoch_end(self.args, self.state)
+                    self.callback_handler.on_train_end(self.args, self.state)
                     return
 
+            self.callback_handler.on_epoch_end(self.args, self.state)
+
+        self.callback_handler.on_train_end(self.args, self.state)
 
     def save_model(self) -> None:
         """Save the model."""
         if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
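Because the trainer decides when to log and hands a plain logs dict to on_log, log sinks stay out of the training loop. A hedged sketch of a file-based sink, again building on the classes above (JsonlCallback and its path argument are hypothetical; the on_log signature follows the call in this hunk):

import json

class JsonlCallback(TrainerCallback):
    """Appends each logs dict emitted by the trainer to a JSONL file."""
    def __init__(self, path="train_log.jsonl"):
        self.path = path

    def on_log(self, args, state, logs):
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")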
@@ -234,3 +284,5 @@ class BaseTrainer:
         model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
         self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
         logger.info_rank0(f"Model saved to {self.args.output_dir}")
+
+        self.callback_handler.on_save(self.args, self.state)
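Putting it together, extra callbacks enter through the new constructor argument and receive every hook alongside the built-in LoggingCallback. A hedged usage sketch (only the parameter names shown in this diff are taken from the commit; the surrounding setup objects are assumed to exist):

trainer = BaseTrainer(
    args=training_args,   # TrainingArguments, prepared elsewhere
    model=model,
    renderer=renderer,
    train_dataset=dataset,
    callbacks=[StepTimerCallback(), JsonlCallback("run.jsonl")],
)
trainer.fit()         # fires on_train_begin/.../on_train_end on every callback
trainer.save_model()  # fires on_save after the checkpoint is written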