[callback] add torch profiler callback (#10463)

This commit is contained in:
浮梦
2026-05-20 20:47:52 +08:00
committed by GitHub
parent 8b5ea65770
commit 368c48968f
4 changed files with 320 additions and 6 deletions

View File

@@ -32,7 +32,13 @@ from ..extras.packages import (
)
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .callbacks import (
LogCallback,
ModuleProfilerCallback,
PissaConvertCallback,
ReporterCallback,
TorchProfilerCallback,
)
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if training_args.enable_torch_profiler:
callbacks.append(TorchProfilerCallback(training_args))
if training_args.profile_modules:
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: