mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[callback] add torch profiler callback (#10463)
This commit is contained in:
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
TrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||
|
||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_MCA_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_MCA_CLS = tuple[
|
||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
else:
|
||||
_TRAIN_MCA_ARGS = []
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -63,6 +64,58 @@ class RayArguments:
|
||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilerArguments:
|
||||
r"""Arguments for torch profiler configuration."""
|
||||
|
||||
enable_torch_profiler: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
|
||||
)
|
||||
profiler_output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
|
||||
)
|
||||
profiler_wait_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
|
||||
)
|
||||
profiler_warmup_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of profiler warm-up steps per cycle."},
|
||||
)
|
||||
profiler_active_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of steps to actively record per cycle."},
|
||||
)
|
||||
profiler_repeat: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
|
||||
)
|
||||
profiler_record_shapes: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to record tensor shapes during profiling."},
|
||||
)
|
||||
profiler_profile_memory: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to profile memory usage."},
|
||||
)
|
||||
profiler_with_stack: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to record stack traces during profiling."},
|
||||
)
|
||||
profile_modules: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Comma-separated list of module name patterns to profile with CUDA events. "
|
||||
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
|
||||
"Reports per-module forward/backward timing statistics at each logging step."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
@@ -87,7 +140,7 @@ class Fp8Arguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
@@ -12,12 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@@ -31,7 +34,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
|
||||
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
|
||||
from ..extras.packages import is_safetensors_available
|
||||
|
||||
|
||||
@@ -338,6 +341,96 @@ class LogCallback(TrainerCallback):
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
|
||||
class TorchProfilerCallback(TrainerCallback):
|
||||
r"""A callback for collecting torch.profiler traces during training.
|
||||
|
||||
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
|
||||
|
||||
Configuration fields (in YAML):
|
||||
profiler_output_dir – where to write traces (default: <output_dir>/profiler)
|
||||
profiler_wait_steps – steps to skip at start of each cycle (default: 1)
|
||||
profiler_warmup_steps – profiler warm-up steps per cycle (default: 1)
|
||||
profiler_active_steps – steps to record per cycle (default: 1)
|
||||
profiler_repeat – number of cycles; 0 = forever (default: 1)
|
||||
profiler_record_shapes – record tensor shapes (default: true)
|
||||
profiler_profile_memory – profile memory usage (default: true)
|
||||
profiler_with_stack – record stack traces (default: true)
|
||||
|
||||
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
|
||||
``<profiler_output_dir>/rank_<N>/``.
|
||||
"""
|
||||
|
||||
def __init__(self, training_args: "TrainingArguments") -> None:
|
||||
self.profiler = None
|
||||
self.profiler_args = training_args
|
||||
|
||||
@staticmethod
|
||||
def _get_rank() -> int:
|
||||
import torch.distributed as dist
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
return 0
|
||||
|
||||
@override
|
||||
def on_train_begin(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
) -> None:
|
||||
if self.profiler is not None:
|
||||
self.profiler.stop()
|
||||
self.profiler = None
|
||||
|
||||
pa = self.profiler_args
|
||||
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
|
||||
rank = self._get_rank()
|
||||
trace_dir = os.path.join(output_dir, f"rank_{rank}")
|
||||
os.makedirs(trace_dir, exist_ok=True)
|
||||
|
||||
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||
try:
|
||||
if is_torch_cuda_available():
|
||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
if is_torch_npu_available():
|
||||
activities.append(torch.profiler.ProfilerActivity.NPU)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=activities,
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=pa.profiler_wait_steps,
|
||||
warmup=pa.profiler_warmup_steps,
|
||||
active=pa.profiler_active_steps,
|
||||
repeat=pa.profiler_repeat,
|
||||
),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
||||
record_shapes=pa.profiler_record_shapes,
|
||||
profile_memory=pa.profiler_profile_memory,
|
||||
with_stack=pa.profiler_with_stack,
|
||||
)
|
||||
self.profiler.start()
|
||||
logger.info_rank0(
|
||||
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
|
||||
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
|
||||
)
|
||||
|
||||
@override
|
||||
def on_step_end(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
) -> None:
|
||||
if self.profiler is not None:
|
||||
self.profiler.step()
|
||||
|
||||
@override
|
||||
def on_train_end(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
) -> None:
|
||||
if self.profiler is not None:
|
||||
self.profiler.stop()
|
||||
self.profiler = None
|
||||
logger.info_rank0("TorchProfiler stopped.")
|
||||
|
||||
|
||||
class ReporterCallback(TrainerCallback):
|
||||
r"""A callback for reporting training status to external logger."""
|
||||
|
||||
@@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback):
|
||||
"generating_args": self.generating_args.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ModuleProfilerCallback(TrainerCallback):
|
||||
r"""Profile forward/backward time of specified modules using accelerator events.
|
||||
|
||||
Hooks are registered on modules matching the user-provided name patterns.
|
||||
Timing statistics are logged at each trainer logging step.
|
||||
|
||||
Usage in YAML config:
|
||||
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
|
||||
|
||||
Supports fnmatch wildcards:
|
||||
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_accelerator():
|
||||
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
|
||||
if is_torch_cuda_available():
|
||||
return torch.cuda.Event, torch.cuda.synchronize
|
||||
if is_torch_npu_available():
|
||||
return torch.npu.Event, torch.npu.synchronize
|
||||
return None, None
|
||||
|
||||
def __init__(self, profile_modules: str) -> None:
|
||||
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
|
||||
self._create_event, self._synchronize = self._get_accelerator()
|
||||
self._handles: list[Any] = []
|
||||
self._forward_times: dict[str, list[float]] = defaultdict(list)
|
||||
self._backward_times: dict[str, list[float]] = defaultdict(list)
|
||||
self._pending_forward: dict[str, tuple] = {}
|
||||
self._pending_backward: dict[str, tuple] = {}
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self._create_event is not None
|
||||
|
||||
def _match(self, name: str) -> bool:
|
||||
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
|
||||
|
||||
def _make_forward_pre_hook(self, name: str):
|
||||
def hook(module, input):
|
||||
start = self._create_event(enable_timing=True)
|
||||
end = self._create_event(enable_timing=True)
|
||||
start.record()
|
||||
self._pending_forward[name] = (start, end)
|
||||
|
||||
return hook
|
||||
|
||||
def _make_forward_hook(self, name: str):
|
||||
def hook(module, input, output):
|
||||
pair = self._pending_forward.get(name)
|
||||
if pair is not None:
|
||||
pair[1].record()
|
||||
|
||||
return hook
|
||||
|
||||
def _make_backward_pre_hook(self, name: str):
|
||||
def hook(module, grad_output):
|
||||
start = self._create_event(enable_timing=True)
|
||||
end = self._create_event(enable_timing=True)
|
||||
start.record()
|
||||
self._pending_backward[name] = (start, end)
|
||||
|
||||
return hook
|
||||
|
||||
def _make_backward_hook(self, name: str):
|
||||
def hook(module, grad_input, grad_output):
|
||||
pair = self._pending_backward.get(name)
|
||||
if pair is not None:
|
||||
pair[1].record()
|
||||
|
||||
return hook
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self.enabled:
|
||||
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
|
||||
return
|
||||
|
||||
model = kwargs.get("model")
|
||||
if model is None:
|
||||
return
|
||||
|
||||
matched = []
|
||||
for name, module in model.named_modules():
|
||||
if not name or not self._match(name):
|
||||
continue
|
||||
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
|
||||
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
|
||||
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
|
||||
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
|
||||
matched.append(name)
|
||||
|
||||
if matched:
|
||||
logger.info_rank0(
|
||||
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
|
||||
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||
|
||||
@override
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self._synchronize()
|
||||
|
||||
for name, (start, end) in self._pending_forward.items():
|
||||
self._forward_times[name].append(start.elapsed_time(end))
|
||||
self._pending_forward.clear()
|
||||
|
||||
for name, (start, end) in self._pending_backward.items():
|
||||
self._backward_times[name].append(start.elapsed_time(end))
|
||||
self._pending_backward.clear()
|
||||
|
||||
@override
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
if not self._forward_times and not self._backward_times:
|
||||
return
|
||||
|
||||
lines = ["[ModuleProfiler] Timing (ms):"]
|
||||
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
|
||||
for name in all_names:
|
||||
fwd = self._forward_times.get(name, [])
|
||||
bwd = self._backward_times.get(name, [])
|
||||
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
|
||||
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
|
||||
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
|
||||
|
||||
logger.info_rank0("\n".join(lines))
|
||||
self._forward_times.clear()
|
||||
self._backward_times.clear()
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
for handle in self._handles:
|
||||
handle.remove()
|
||||
self._handles.clear()
|
||||
|
||||
@@ -32,7 +32,13 @@ from ..extras.packages import (
|
||||
)
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .callbacks import (
|
||||
LogCallback,
|
||||
ModuleProfilerCallback,
|
||||
PissaConvertCallback,
|
||||
ReporterCallback,
|
||||
TorchProfilerCallback,
|
||||
)
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .ppo import run_ppo
|
||||
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
if finetuning_args.early_stopping_steps is not None:
|
||||
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
||||
|
||||
if training_args.enable_torch_profiler:
|
||||
callbacks.append(TorchProfilerCallback(training_args))
|
||||
|
||||
if training_args.profile_modules:
|
||||
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||
|
||||
Reference in New Issue
Block a user