From 368c48968f66d7cf26cb27e2af133d911c75a585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=AE=E6=A2=A6?= <46097299+frozenleaves@users.noreply.github.com> Date: Wed, 20 May 2026 20:47:52 +0800 Subject: [PATCH] [callback] add torch profiler callback (#10463) --- src/llamafactory/hparams/parser.py | 22 +- src/llamafactory/hparams/training_args.py | 55 ++++- src/llamafactory/train/callbacks.py | 235 +++++++++++++++++++++- src/llamafactory/train/tuner.py | 14 +- 4 files changed, 320 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index c946cceb9..7c8e3b3fa 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -47,7 +47,13 @@ logger = logging.get_logger(__name__) check_dependencies() -_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_ARGS = [ + ModelArguments, + DataArguments, + TrainingArguments, + FinetuningArguments, + GeneratingArguments, +] _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] @@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning if is_mcore_adapter_available() and is_env_enabled("USE_MCA"): from mcore_adapter import TrainingArguments as McaTrainingArguments - _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments] + _TRAIN_MCA_ARGS = [ + ModelArguments, + DataArguments, + McaTrainingArguments, + FinetuningArguments, + GeneratingArguments, + ] _TRAIN_MCA_CLS = tuple[ - ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments + ModelArguments, + DataArguments, + McaTrainingArguments, + FinetuningArguments, + 
GeneratingArguments, ] else: _TRAIN_MCA_ARGS = [] diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index bf975f0d2..84f4d8e6c 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -14,6 +14,7 @@ import json from dataclasses import dataclass, field +from typing import Optional from transformers import Seq2SeqTrainingArguments from transformers.training_args import _convert_str_dict @@ -63,6 +64,58 @@ class RayArguments: self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs)) +@dataclass +class ProfilerArguments: + r"""Arguments for torch profiler configuration.""" + + enable_torch_profiler: bool = field( + default=False, + metadata={"help": "Whether to enable torch profiler for collecting performance traces."}, + ) + profiler_output_dir: Optional[str] = field( + default=None, + metadata={"help": "Directory to write profiler traces. Defaults to {output_dir}/profiler if not set."}, + ) + profiler_wait_steps: int = field( + default=1, + metadata={"help": "Number of steps to skip at the start of each profiling cycle."}, + ) + profiler_warmup_steps: int = field( + default=1, + metadata={"help": "Number of profiler warm-up steps per cycle."}, + ) + profiler_active_steps: int = field( + default=1, + metadata={"help": "Number of steps to actively record per cycle."}, + ) + profiler_repeat: int = field( + default=1, + metadata={"help": "Number of profiling cycles.
Set to 0 for continuous profiling."}, + ) + profiler_record_shapes: bool = field( + default=True, + metadata={"help": "Whether to record tensor shapes during profiling."}, + ) + profiler_profile_memory: bool = field( + default=True, + metadata={"help": "Whether to profile memory usage."}, + ) + profiler_with_stack: bool = field( + default=True, + metadata={"help": "Whether to record stack traces during profiling."}, + ) + profile_modules: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Comma-separated list of module name patterns to profile with CUDA events. " + "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). " + "Reports per-module forward/backward timing statistics at each logging step." + ) + }, + ) + + @dataclass class Fp8Arguments: r"""Arguments pertaining to the FP8 training.""" @@ -87,7 +140,7 @@ class Fp8Arguments: @dataclass -class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments): +class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments): r"""Arguments pertaining to the trainer.""" overwrite_output_dir: bool = field( diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 77507c848..3dc7fd730 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import fnmatch import json import os import signal import sys import time +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional @@ -31,7 +34,7 @@ from typing_extensions import override from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.misc import get_peak_memory, is_env_enabled, use_ray +from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray from ..extras.packages import is_safetensors_available @@ -338,6 +341,96 @@ class LogCallback(TrainerCallback): self.thread_pool.submit(self._write_log, args.output_dir, logs) +class TorchProfilerCallback(TrainerCallback): + r"""A callback for collecting torch.profiler traces during training. + + Activated by setting ``enable_torch_profiler: true`` in the YAML config. + + Configuration fields (in YAML): + profiler_output_dir – where to write traces (default: {output_dir}/profiler) + profiler_wait_steps – steps to skip at start of each cycle (default: 1) + profiler_warmup_steps – profiler warm-up steps per cycle (default: 1) + profiler_active_steps – steps to record per cycle (default: 1) + profiler_repeat – number of cycles; 0 = forever (default: 1) + profiler_record_shapes – record tensor shapes (default: true) + profiler_profile_memory – profile memory usage (default: true) + profiler_with_stack – record stack traces (default: true) + + Trace files (one per rank, Chrome / TensorBoard JSON format) are written to + ``{profiler_output_dir}/rank_{rank}/``.
+ """ + + def __init__(self, training_args: "TrainingArguments") -> None: + self.profiler = None + self.profiler_args = training_args + + @staticmethod + def _get_rank() -> int: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + @override + def on_train_begin( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ) -> None: + if self.profiler is not None: + self.profiler.stop() + self.profiler = None + + pa = self.profiler_args + output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler") + rank = self._get_rank() + trace_dir = os.path.join(output_dir, f"rank_{rank}") + os.makedirs(trace_dir, exist_ok=True) + + activities = [torch.profiler.ProfilerActivity.CPU] + try: + if is_torch_cuda_available(): + activities.append(torch.profiler.ProfilerActivity.CUDA) + if is_torch_npu_available(): + activities.append(torch.profiler.ProfilerActivity.NPU) + except Exception: + pass + + self.profiler = torch.profiler.profile( + activities=activities, + schedule=torch.profiler.schedule( + wait=pa.profiler_wait_steps, + warmup=pa.profiler_warmup_steps, + active=pa.profiler_active_steps, + repeat=pa.profiler_repeat, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir), + record_shapes=pa.profiler_record_shapes, + profile_memory=pa.profiler_profile_memory, + with_stack=pa.profiler_with_stack, + ) + self.profiler.start() + logger.info_rank0( + f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, " + f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. 
Traces → {output_dir}" + ) + + @override + def on_step_end( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ) -> None: + if self.profiler is not None: + self.profiler.step() + + @override + def on_train_end( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ) -> None: + if self.profiler is not None: + self.profiler.stop() + self.profiler = None + logger.info_rank0("TorchProfiler stopped.") + + class ReporterCallback(TrainerCallback): r"""A callback for reporting training status to external logger.""" @@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback): "generating_args": self.generating_args.to_dict(), } ) + + +class ModuleProfilerCallback(TrainerCallback): + r"""Profile forward/backward time of specified modules using accelerator events. + + Hooks are registered on modules matching the user-provided name patterns. + Timing statistics are logged at each trainer logging step. + + Usage in YAML config: + profile_modules: "*.layers.0.self_attn,*.layers.0.mlp" + + Supports fnmatch wildcards: + profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts" + """ + + @staticmethod + def _get_accelerator(): + """Detect available accelerator and return (event_factory, synchronize_fn).""" + if is_torch_cuda_available(): + return torch.cuda.Event, torch.cuda.synchronize + if is_torch_npu_available(): + return torch.npu.Event, torch.npu.synchronize + return None, None + + def __init__(self, profile_modules: str) -> None: + self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()] + self._create_event, self._synchronize = self._get_accelerator() + self._handles: list[Any] = [] + self._forward_times: dict[str, list[float]] = defaultdict(list) + self._backward_times: dict[str, list[float]] = defaultdict(list) + self._pending_forward: dict[str, tuple] = {} + self._pending_backward: dict[str, tuple] = {} + + @property + def enabled(self) -> bool: + return 
self._create_event is not None + + def _match(self, name: str) -> bool: + return any(fnmatch.fnmatch(name, pat) for pat in self.patterns) + + def _make_forward_pre_hook(self, name: str): + def hook(module, input): + start = self._create_event(enable_timing=True) + end = self._create_event(enable_timing=True) + start.record() + self._pending_forward[name] = (start, end) + + return hook + + def _make_forward_hook(self, name: str): + def hook(module, input, output): + pair = self._pending_forward.get(name) + if pair is not None: + pair[1].record() + + return hook + + def _make_backward_pre_hook(self, name: str): + def hook(module, grad_output): + start = self._create_event(enable_timing=True) + end = self._create_event(enable_timing=True) + start.record() + self._pending_backward[name] = (start, end) + + return hook + + def _make_backward_hook(self, name: str): + def hook(module, grad_input, grad_output): + pair = self._pending_backward.get(name) + if pair is not None: + pair[1].record() + + return hook + + @override + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not self.enabled: + logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.") + return + + model = kwargs.get("model") + if model is None: + return + + matched = [] + for name, module in model.named_modules(): + if not name or not self._match(name): + continue + self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name))) + self._handles.append(module.register_forward_hook(self._make_forward_hook(name))) + self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name))) + self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name))) + matched.append(name) + + if matched: + logger.info_rank0( + f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}" + + (f" ... 
(+{len(matched)-5} more)" if len(matched) > 5 else "") + ) + else: + logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}") + + @override + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not self.enabled: + return + + self._synchronize() + + for name, (start, end) in self._pending_forward.items(): + self._forward_times[name].append(start.elapsed_time(end)) + self._pending_forward.clear() + + for name, (start, end) in self._pending_backward.items(): + self._backward_times[name].append(start.elapsed_time(end)) + self._pending_backward.clear() + + @override + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + if not self._forward_times and not self._backward_times: + return + + lines = ["[ModuleProfiler] Timing (ms):"] + all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys()))) + for name in all_names: + fwd = self._forward_times.get(name, []) + bwd = self._backward_times.get(name, []) + fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0 + bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0 + lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}") + + logger.info_rank0("\n".join(lines)) + self._forward_times.clear() + self._backward_times.clear() + + @override + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + for handle in self._handles: + handle.remove() + self._handles.clear() diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 480ac19de..dcde974b7 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -32,7 +32,13 @@ from ..extras.packages import ( ) from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer -from .callbacks import 
LogCallback, PissaConvertCallback, ReporterCallback +from .callbacks import ( + LogCallback, + ModuleProfilerCallback, + PissaConvertCallback, + ReporterCallback, + TorchProfilerCallback, +) from .dpo import run_dpo from .kto import run_kto from .ppo import run_ppo @@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None: if finetuning_args.early_stopping_steps is not None: callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps)) + if training_args.enable_torch_profiler: + callbacks.append(TorchProfilerCallback(training_args)) + + if training_args.profile_modules: + callbacks.append(ModuleProfilerCallback(training_args.profile_modules)) + callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: