[callback] add torch profiler callback (#10463)

This commit is contained in:
浮梦
2026-05-20 20:47:52 +08:00
committed by GitHub
parent 8b5ea65770
commit 368c48968f
4 changed files with 320 additions and 6 deletions

View File

@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies() check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"): if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[ _TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
] ]
else: else:
_TRAIN_MCA_ARGS = [] _TRAIN_MCA_ARGS = []

View File

@@ -14,6 +14,7 @@
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs)) self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
@dataclass
class ProfilerArguments:
    r"""Options for torch profiler tracing and per-module timing.

    The ``profiler_*`` fields feed the profiler schedule and trace options;
    ``profile_modules`` selects modules by name pattern for event-based timing.
    """

    # --- trace collection switch and destination ---------------------------
    enable_torch_profiler: bool = field(
        default=False, metadata={"help": "Whether to enable torch profiler for collecting performance traces."}
    )
    profiler_output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
    )

    # --- schedule: wait -> warmup -> active, repeated ----------------------
    profiler_wait_steps: int = field(
        default=1, metadata={"help": "Number of steps to skip at the start of each profiling cycle."}
    )
    profiler_warmup_steps: int = field(
        default=1, metadata={"help": "Number of profiler warm-up steps per cycle."}
    )
    profiler_active_steps: int = field(
        default=1, metadata={"help": "Number of steps to actively record per cycle."}
    )
    profiler_repeat: int = field(
        default=1, metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."}
    )

    # --- what each recorded step captures ----------------------------------
    profiler_record_shapes: bool = field(
        default=True, metadata={"help": "Whether to record tensor shapes during profiling."}
    )
    profiler_profile_memory: bool = field(
        default=True, metadata={"help": "Whether to profile memory usage."}
    )
    profiler_with_stack: bool = field(
        default=True, metadata={"help": "Whether to record stack traces during profiling."}
    )

    # --- per-module timing (independent of the trace options above) --------
    profile_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": "Comma-separated list of module name patterns to profile with CUDA events. "
            "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
            "Reports per-module forward/backward timing statistics at each logging step."
        },
    )
@dataclass @dataclass
class Fp8Arguments: class Fp8Arguments:
r"""Arguments pertaining to the FP8 training.""" r"""Arguments pertaining to the FP8 training."""
@@ -87,7 +140,7 @@ class Fp8Arguments:
@dataclass @dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments): class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer.""" r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field( overwrite_output_dir: bool = field(

View File

@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import fnmatch
import json import json
import os import os
import signal import signal
import sys import sys
import time import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +34,7 @@ from typing_extensions import override
from ..extras import logging from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
from ..extras.packages import is_safetensors_available from ..extras.packages import is_safetensors_available
@@ -338,6 +341,96 @@ class LogCallback(TrainerCallback):
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)
class TorchProfilerCallback(TrainerCallback):
    r"""A callback for collecting torch.profiler traces during training.

    Activated by setting ``enable_torch_profiler: true`` in the YAML config.

    Configuration fields (in YAML):
        profiler_output_dir      where to write traces (default: <output_dir>/profiler)
        profiler_wait_steps      steps to skip at start of each cycle (default: 1)
        profiler_warmup_steps    profiler warm-up steps per cycle (default: 1)
        profiler_active_steps    steps to record per cycle (default: 1)
        profiler_repeat          number of cycles; 0 = forever (default: 1)
        profiler_record_shapes   record tensor shapes (default: true)
        profiler_profile_memory  profile memory usage (default: true)
        profiler_with_stack      record stack traces (default: true)

    Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
    ``<profiler_output_dir>/rank_<N>/``.
    """

    def __init__(self, training_args: "TrainingArguments") -> None:
        """Store the arguments object the ``profiler_*`` options are read from."""
        self.profiler = None
        self.profiler_args = training_args

    @staticmethod
    def _get_rank() -> int:
        """Return the distributed rank, or 0 when no process group is initialized."""
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()

        return 0

    @staticmethod
    def _get_activities() -> list:
        """Build the list of profiler activities for the detected accelerators.

        CPU is always profiled. CUDA / NPU are added when the corresponding
        availability check passes AND ``torch.profiler.ProfilerActivity`` exposes
        a member of that name. A missing member is reported with a warning and
        skipped, so profiling degrades to the remaining activities instead of
        either aborting training or silently dropping the accelerator.
        """
        activities = [torch.profiler.ProfilerActivity.CPU]
        requested = []
        if is_torch_cuda_available():
            requested.append("CUDA")

        if is_torch_npu_available():
            requested.append("NPU")

        for activity_name in requested:
            # Look the member up by name: not every torch build is guaranteed
            # to define every activity.
            activity = getattr(torch.profiler.ProfilerActivity, activity_name, None)
            if activity is None:
                logger.warning_rank0(
                    f"TorchProfiler: torch.profiler.ProfilerActivity has no '{activity_name}' member, "
                    f"{activity_name} activity will not be recorded."
                )
            else:
                activities.append(activity)

        return activities

    @override
    def on_train_begin(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        """Create the per-rank trace directory and start a fresh profiler."""
        # A profiler left over from a previous train() call is stopped first.
        if self.profiler is not None:
            self.profiler.stop()
            self.profiler = None

        pa = self.profiler_args
        output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
        rank = self._get_rank()
        trace_dir = os.path.join(output_dir, f"rank_{rank}")
        os.makedirs(trace_dir, exist_ok=True)

        self.profiler = torch.profiler.profile(
            activities=self._get_activities(),
            schedule=torch.profiler.schedule(
                wait=pa.profiler_wait_steps,
                warmup=pa.profiler_warmup_steps,
                active=pa.profiler_active_steps,
                repeat=pa.profiler_repeat,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=pa.profiler_record_shapes,
            profile_memory=pa.profiler_profile_memory,
            with_stack=pa.profiler_with_stack,
        )
        self.profiler.start()
        logger.info_rank0(
            f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
            f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
        )

    @override
    def on_step_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        """Advance the profiler schedule by one step."""
        if self.profiler is not None:
            self.profiler.step()

    @override
    def on_train_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        """Stop the profiler if one is running."""
        if self.profiler is None:
            return

        self.profiler.stop()
        self.profiler = None
        # Only report a stop that actually happened.
        logger.info_rank0("TorchProfiler stopped.")
class ReporterCallback(TrainerCallback): class ReporterCallback(TrainerCallback):
r"""A callback for reporting training status to external logger.""" r"""A callback for reporting training status to external logger."""
@@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback):
"generating_args": self.generating_args.to_dict(), "generating_args": self.generating_args.to_dict(),
} }
) )
class ModuleProfilerCallback(TrainerCallback):
    r"""Profile forward/backward time of specified modules using accelerator events.

    Hooks are registered on modules matching the user-provided name patterns.
    Timing statistics are logged at each trainer logging step.

    Every hooked invocation between two ``on_step_end`` calls is timed, so a
    module that runs several times per step (for example once per accumulated
    micro-batch) contributes one sample per invocation to the reported mean.

    Usage in YAML config:
        profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"

    Supports fnmatch wildcards:
        profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
    """

    @staticmethod
    def _get_accelerator():
        """Detect available accelerator and return (event_factory, synchronize_fn)."""
        if is_torch_cuda_available():
            return torch.cuda.Event, torch.cuda.synchronize

        if is_torch_npu_available():
            return torch.npu.Event, torch.npu.synchronize

        return None, None

    def __init__(self, profile_modules: str) -> None:
        """Parse the comma-separated patterns and set up empty timing state."""
        self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
        self._create_event, self._synchronize = self._get_accelerator()
        self._handles: list[Any] = []
        # module name -> elapsed milliseconds collected since the last log
        self._forward_times: dict[str, list[float]] = defaultdict(list)
        self._backward_times: dict[str, list[float]] = defaultdict(list)
        # module name -> spans opened since the last step end; each span is
        # [start_event, end_event, end_recorded]
        self._pending_forward: dict[str, list[list]] = defaultdict(list)
        self._pending_backward: dict[str, list[list]] = defaultdict(list)

    @property
    def enabled(self) -> bool:
        """Whether an accelerator event type was found."""
        return self._create_event is not None

    def _match(self, name: str) -> bool:
        """Return True if the module name matches any configured pattern."""
        return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)

    def _open_span(self, pending: dict, name: str) -> None:
        """Record a start event and queue a new, still-open span for ``name``."""
        start = self._create_event(enable_timing=True)
        end = self._create_event(enable_timing=True)
        start.record()
        pending[name].append([start, end, False])

    @staticmethod
    def _close_span(pending: dict, name: str) -> None:
        """Record the end event of the most recently opened, still-open span."""
        for span in reversed(pending.get(name, ())):
            if not span[2]:
                span[1].record()
                span[2] = True
                return

    @staticmethod
    def _collect(pending: dict, times: dict) -> None:
        """Move elapsed times of closed spans into ``times`` and reset ``pending``."""
        for name, spans in pending.items():
            for start, end, closed in spans:
                # A span whose end event was never recorded cannot be measured.
                if closed:
                    times[name].append(start.elapsed_time(end))

        pending.clear()

    def _remove_hooks(self) -> None:
        """Detach every hook registered by this callback."""
        for handle in self._handles:
            handle.remove()

        self._handles.clear()

    def _make_forward_pre_hook(self, name: str):
        def hook(module, input):
            self._open_span(self._pending_forward, name)

        return hook

    def _make_forward_hook(self, name: str):
        def hook(module, input, output):
            self._close_span(self._pending_forward, name)

        return hook

    def _make_backward_pre_hook(self, name: str):
        def hook(module, grad_output):
            self._open_span(self._pending_backward, name)

        return hook

    def _make_backward_hook(self, name: str):
        def hook(module, grad_input, grad_output):
            self._close_span(self._pending_backward, name)

        return hook

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Register forward/backward timing hooks on every matching module."""
        if not self.enabled:
            logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
            return

        model = kwargs.get("model")
        if model is None:
            return

        # Drop hooks from an earlier train() call so they are not stacked twice.
        self._remove_hooks()

        matched = []
        for name, module in model.named_modules():
            # The root module has an empty name and is never hooked.
            if not name or not self._match(name):
                continue

            self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
            self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
            self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
            self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
            matched.append(name)

        if matched:
            logger.info_rank0(
                f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
                + (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
            )
        else:
            logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Synchronize the accelerator and convert pending spans into timings."""
        if not self.enabled:
            return

        # Synchronize before reading elapsed times from the recorded events.
        self._synchronize()
        self._collect(self._pending_forward, self._forward_times)
        self._collect(self._pending_backward, self._backward_times)

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Log mean forward/backward time per module, then reset the statistics."""
        if not self._forward_times and not self._backward_times:
            return

        lines = ["[ModuleProfiler] Timing (ms):"]
        for name in sorted(set(self._forward_times) | set(self._backward_times)):
            fwd = self._forward_times.get(name, [])
            bwd = self._backward_times.get(name, [])
            fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
            bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
            lines.append(f"  {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")

        logger.info_rank0("\n".join(lines))
        self._forward_times.clear()
        self._backward_times.clear()

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        """Remove all registered hooks."""
        self._remove_hooks()

View File

@@ -32,7 +32,13 @@ from ..extras.packages import (
) )
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .callbacks import (
LogCallback,
ModuleProfilerCallback,
PissaConvertCallback,
ReporterCallback,
TorchProfilerCallback,
)
from .dpo import run_dpo from .dpo import run_dpo
from .kto import run_kto from .kto import run_kto
from .ppo import run_ppo from .ppo import run_ppo
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None: if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps)) callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if training_args.enable_torch_profiler:
callbacks.append(TorchProfilerCallback(training_args))
if training_args.profile_modules:
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: